forked from M-Labs/nac3
[core] codegen/ndarray: Implement subscript assignment
Based on 5bed394e: core/ndstrides: implement subscript assignment Overlapping is not handled. Currently it has undefined behavior.
This commit is contained in:
parent
dcde1d9c87
commit
2dc5e79a23
@ -16,7 +16,11 @@ use super::{
|
||||
gen_in_range_check,
|
||||
irrt::{handle_slice_indices, list_slice_assignment},
|
||||
macros::codegen_unreachable,
|
||||
values::{ArrayLikeIndexer, ArraySliceValue, ListValue, RangeValue},
|
||||
types::ndarray::NDArrayType,
|
||||
values::{
|
||||
ndarray::{RustNDIndex, ScalarOrNDArray},
|
||||
ArrayLikeIndexer, ArraySliceValue, ListValue, ProxyValue, RangeValue,
|
||||
},
|
||||
CodeGenContext, CodeGenerator,
|
||||
};
|
||||
use crate::{
|
||||
@ -411,7 +415,54 @@ pub fn gen_setitem<'ctx, G: CodeGenerator>(
|
||||
if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() =>
|
||||
{
|
||||
// Handle NDArray item assignment
|
||||
todo!("ndarray subscript assignment is not yet implemented");
|
||||
// Process target
|
||||
let target = generator
|
||||
.gen_expr(ctx, target)?
|
||||
.unwrap()
|
||||
.to_basic_value_enum(ctx, generator, target_ty)?;
|
||||
|
||||
// Process key
|
||||
let key = RustNDIndex::from_subscript_expr(generator, ctx, key)?;
|
||||
|
||||
// Process value
|
||||
let value = value.to_basic_value_enum(ctx, generator, value_ty)?;
|
||||
|
||||
// Reference code:
|
||||
// ```python
|
||||
// target = target[key]
|
||||
// value = np.asarray(value)
|
||||
//
|
||||
// shape = np.broadcast_shape((target, value))
|
||||
//
|
||||
// target = np.broadcast_to(target, shape)
|
||||
// value = np.broadcast_to(value, shape)
|
||||
//
|
||||
// # ...and finally copy 1-1 from value to target.
|
||||
// ```
|
||||
|
||||
let target = NDArrayType::from_unifier_type(generator, ctx, target_ty)
|
||||
.map_value(target.into_pointer_value(), None);
|
||||
let target = target.index(generator, ctx, &key);
|
||||
|
||||
let value = ScalarOrNDArray::from_value(generator, ctx, (value_ty, value))
|
||||
.to_ndarray(generator, ctx);
|
||||
|
||||
let broadcast_ndims = [target.get_type().ndims(), value.get_type().ndims()]
|
||||
.iter()
|
||||
.filter_map(|ndims| *ndims)
|
||||
.max();
|
||||
let broadcast_result = NDArrayType::new(
|
||||
generator,
|
||||
ctx.ctx,
|
||||
value.get_type().element_type(),
|
||||
broadcast_ndims,
|
||||
)
|
||||
.broadcast(generator, ctx, &[target, value]);
|
||||
|
||||
let target = broadcast_result.ndarrays[0];
|
||||
let value = broadcast_result.ndarrays[1];
|
||||
|
||||
target.copy_data_from(generator, ctx, value);
|
||||
}
|
||||
_ => {
|
||||
panic!("encountered unknown target type: {}", ctx.unifier.stringify(target_ty));
|
||||
|
@ -276,6 +276,38 @@ def test_ndarray_broadcast_to():
|
||||
output_int32(np_shape(zs)[2])
|
||||
output_ndarray_float_3(zs)
|
||||
|
||||
def test_ndarray_subscript_assignment():
|
||||
xs = np_array([[11.0, 22.0, 33.0, 44.0], [55.0, 66.0, 77.0, 88.0]])
|
||||
|
||||
xs[0, 0] = 99.0
|
||||
output_ndarray_float_2(xs)
|
||||
|
||||
xs[0] = 100.0
|
||||
output_ndarray_float_2(xs)
|
||||
|
||||
xs[:, ::2] = 101.0
|
||||
output_ndarray_float_2(xs)
|
||||
|
||||
xs[1:, 0] = 102.0
|
||||
output_ndarray_float_2(xs)
|
||||
|
||||
xs[0] = np_array([-1.0, -2.0, -3.0, -4.0])
|
||||
output_ndarray_float_2(xs)
|
||||
|
||||
xs[:] = np_array([-5.0, -6.0, -7.0, -8.0])
|
||||
output_ndarray_float_2(xs)
|
||||
|
||||
# Test assignment with memory sharing
|
||||
ys1 = np_reshape(xs, (2, 4))
|
||||
ys2 = np_transpose(ys1)
|
||||
ys3 = ys2[::-1, 0]
|
||||
ys3[0] = -999.0
|
||||
|
||||
output_ndarray_float_2(xs)
|
||||
output_ndarray_float_2(ys1)
|
||||
output_ndarray_float_2(ys2)
|
||||
output_ndarray_float_1(ys3)
|
||||
|
||||
def test_ndarray_add():
|
||||
x = np_identity(2)
|
||||
y = x + np_ones([2, 2])
|
||||
@ -1653,6 +1685,7 @@ def run() -> int32:
|
||||
test_ndarray_transpose()
|
||||
test_ndarray_reshape()
|
||||
test_ndarray_broadcast_to()
|
||||
test_ndarray_subscript_assignment()
|
||||
|
||||
test_ndarray_add()
|
||||
test_ndarray_add_broadcast()
|
||||
|
Loading…
Reference in New Issue
Block a user