[core] codegen/ndarray: Implement subscript assignment

Based on 5bed394e: core/ndstrides: implement subscript assignment

Overlapping is not handled. Currently it has undefined behavior.
This commit is contained in:
David Mak 2024-12-18 16:44:57 +08:00
parent 8d8f9e9b2a
commit 5c8a1d5f2f
2 changed files with 88 additions and 2 deletions

View File

@ -16,7 +16,11 @@ use super::{
gen_in_range_check, gen_in_range_check,
irrt::{handle_slice_indices, list_slice_assignment}, irrt::{handle_slice_indices, list_slice_assignment},
macros::codegen_unreachable, macros::codegen_unreachable,
values::{ArrayLikeIndexer, ArraySliceValue, ListValue, RangeValue}, types::ndarray::NDArrayType,
values::{
ndarray::{RustNDIndex, ScalarOrNDArray},
ArrayLikeIndexer, ArraySliceValue, ListValue, ProxyValue, RangeValue,
},
CodeGenContext, CodeGenerator, CodeGenContext, CodeGenerator,
}; };
use crate::{ use crate::{
@ -411,7 +415,56 @@ pub fn gen_setitem<'ctx, G: CodeGenerator>(
if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() => if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() =>
{ {
// Handle NDArray item assignment // Handle NDArray item assignment
todo!("ndarray subscript assignment is not yet implemented"); // Process target
let target = generator
.gen_expr(ctx, target)?
.unwrap()
.to_basic_value_enum(ctx, generator, target_ty)?;
// Process key
let key = RustNDIndex::from_subscript_expr(generator, ctx, key)?;
// Process value
let value = value.to_basic_value_enum(ctx, generator, value_ty)?;
/*
Reference code:
```python
target = target[key]
value = np.asarray(value)
shape = np.broadcast_shape((target, value))
target = np.broadcast_to(target, shape)
value = np.broadcast_to(value, shape)
...and finally copy 1-1 from value to target.
```
*/
let target = NDArrayType::from_unifier_type(generator, ctx, target_ty)
.map_value(target.into_pointer_value(), None);
let target = target.index(generator, ctx, &key);
let value = ScalarOrNDArray::from_value(generator, ctx, (value_ty, value))
.to_ndarray(generator, ctx);
let broadcast_ndims = [target.get_type().ndims(), value.get_type().ndims()]
.iter()
.filter_map(|ndims| *ndims)
.max();
let broadcast_result = NDArrayType::new(
generator,
ctx.ctx,
value.get_type().element_type(),
broadcast_ndims,
)
.broadcast(generator, ctx, &[target, value]);
let target = broadcast_result.ndarrays[0];
let value = broadcast_result.ndarrays[1];
target.copy_data_from(generator, ctx, value);
} }
_ => { _ => {
panic!("encountered unknown target type: {}", ctx.unifier.stringify(target_ty)); panic!("encountered unknown target type: {}", ctx.unifier.stringify(target_ty));

View File

@ -276,6 +276,38 @@ def test_ndarray_broadcast_to():
output_int32(np_shape(zs)[2]) output_int32(np_shape(zs)[2])
output_ndarray_float_3(zs) output_ndarray_float_3(zs)
def test_ndarray_subscript_assignment():
xs = np_array([[11.0, 22.0, 33.0, 44.0], [55.0, 66.0, 77.0, 88.0]])
xs[0, 0] = 99.0
output_ndarray_float_2(xs)
xs[0] = 100.0
output_ndarray_float_2(xs)
xs[:, ::2] = 101.0
output_ndarray_float_2(xs)
xs[1:, 0] = 102.0
output_ndarray_float_2(xs)
xs[0] = np_array([-1.0, -2.0, -3.0, -4.0])
output_ndarray_float_2(xs)
xs[:] = np_array([-5.0, -6.0, -7.0, -8.0])
output_ndarray_float_2(xs)
# Test assignment with memory sharing
ys1 = np_reshape(xs, (2, 4))
ys2 = np_transpose(ys1)
ys3 = ys2[::-1, 0]
ys3[0] = -999.0
output_ndarray_float_2(xs)
output_ndarray_float_2(ys1)
output_ndarray_float_2(ys2)
output_ndarray_float_1(ys3)
def test_ndarray_add(): def test_ndarray_add():
x = np_identity(2) x = np_identity(2)
y = x + np_ones([2, 2]) y = x + np_ones([2, 2])
@ -1653,6 +1685,7 @@ def run() -> int32:
test_ndarray_transpose() test_ndarray_transpose()
test_ndarray_reshape() test_ndarray_reshape()
test_ndarray_broadcast_to() test_ndarray_broadcast_to()
test_ndarray_subscript_assignment()
test_ndarray_add() test_ndarray_add()
test_ndarray_add_broadcast() test_ndarray_add_broadcast()