forked from M-Labs/nac3
[core] codegen/ndarray: Implement subscript assignment
Based on 5bed394e
: core/ndstrides: implement subscript assignment
Overlapping is not handled. Currently it has undefined behavior.
This commit is contained in:
parent
dcde1d9c87
commit
2dc5e79a23
@ -16,7 +16,11 @@ use super::{
|
|||||||
gen_in_range_check,
|
gen_in_range_check,
|
||||||
irrt::{handle_slice_indices, list_slice_assignment},
|
irrt::{handle_slice_indices, list_slice_assignment},
|
||||||
macros::codegen_unreachable,
|
macros::codegen_unreachable,
|
||||||
values::{ArrayLikeIndexer, ArraySliceValue, ListValue, RangeValue},
|
types::ndarray::NDArrayType,
|
||||||
|
values::{
|
||||||
|
ndarray::{RustNDIndex, ScalarOrNDArray},
|
||||||
|
ArrayLikeIndexer, ArraySliceValue, ListValue, ProxyValue, RangeValue,
|
||||||
|
},
|
||||||
CodeGenContext, CodeGenerator,
|
CodeGenContext, CodeGenerator,
|
||||||
};
|
};
|
||||||
use crate::{
|
use crate::{
|
||||||
@ -411,7 +415,54 @@ pub fn gen_setitem<'ctx, G: CodeGenerator>(
|
|||||||
if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() =>
|
if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() =>
|
||||||
{
|
{
|
||||||
// Handle NDArray item assignment
|
// Handle NDArray item assignment
|
||||||
todo!("ndarray subscript assignment is not yet implemented");
|
// Process target
|
||||||
|
let target = generator
|
||||||
|
.gen_expr(ctx, target)?
|
||||||
|
.unwrap()
|
||||||
|
.to_basic_value_enum(ctx, generator, target_ty)?;
|
||||||
|
|
||||||
|
// Process key
|
||||||
|
let key = RustNDIndex::from_subscript_expr(generator, ctx, key)?;
|
||||||
|
|
||||||
|
// Process value
|
||||||
|
let value = value.to_basic_value_enum(ctx, generator, value_ty)?;
|
||||||
|
|
||||||
|
// Reference code:
|
||||||
|
// ```python
|
||||||
|
// target = target[key]
|
||||||
|
// value = np.asarray(value)
|
||||||
|
//
|
||||||
|
// shape = np.broadcast_shape((target, value))
|
||||||
|
//
|
||||||
|
// target = np.broadcast_to(target, shape)
|
||||||
|
// value = np.broadcast_to(value, shape)
|
||||||
|
//
|
||||||
|
// # ...and finally copy 1-1 from value to target.
|
||||||
|
// ```
|
||||||
|
|
||||||
|
let target = NDArrayType::from_unifier_type(generator, ctx, target_ty)
|
||||||
|
.map_value(target.into_pointer_value(), None);
|
||||||
|
let target = target.index(generator, ctx, &key);
|
||||||
|
|
||||||
|
let value = ScalarOrNDArray::from_value(generator, ctx, (value_ty, value))
|
||||||
|
.to_ndarray(generator, ctx);
|
||||||
|
|
||||||
|
let broadcast_ndims = [target.get_type().ndims(), value.get_type().ndims()]
|
||||||
|
.iter()
|
||||||
|
.filter_map(|ndims| *ndims)
|
||||||
|
.max();
|
||||||
|
let broadcast_result = NDArrayType::new(
|
||||||
|
generator,
|
||||||
|
ctx.ctx,
|
||||||
|
value.get_type().element_type(),
|
||||||
|
broadcast_ndims,
|
||||||
|
)
|
||||||
|
.broadcast(generator, ctx, &[target, value]);
|
||||||
|
|
||||||
|
let target = broadcast_result.ndarrays[0];
|
||||||
|
let value = broadcast_result.ndarrays[1];
|
||||||
|
|
||||||
|
target.copy_data_from(generator, ctx, value);
|
||||||
}
|
}
|
||||||
_ => {
|
_ => {
|
||||||
panic!("encountered unknown target type: {}", ctx.unifier.stringify(target_ty));
|
panic!("encountered unknown target type: {}", ctx.unifier.stringify(target_ty));
|
||||||
|
@ -276,6 +276,38 @@ def test_ndarray_broadcast_to():
|
|||||||
output_int32(np_shape(zs)[2])
|
output_int32(np_shape(zs)[2])
|
||||||
output_ndarray_float_3(zs)
|
output_ndarray_float_3(zs)
|
||||||
|
|
||||||
|
def test_ndarray_subscript_assignment():
|
||||||
|
xs = np_array([[11.0, 22.0, 33.0, 44.0], [55.0, 66.0, 77.0, 88.0]])
|
||||||
|
|
||||||
|
xs[0, 0] = 99.0
|
||||||
|
output_ndarray_float_2(xs)
|
||||||
|
|
||||||
|
xs[0] = 100.0
|
||||||
|
output_ndarray_float_2(xs)
|
||||||
|
|
||||||
|
xs[:, ::2] = 101.0
|
||||||
|
output_ndarray_float_2(xs)
|
||||||
|
|
||||||
|
xs[1:, 0] = 102.0
|
||||||
|
output_ndarray_float_2(xs)
|
||||||
|
|
||||||
|
xs[0] = np_array([-1.0, -2.0, -3.0, -4.0])
|
||||||
|
output_ndarray_float_2(xs)
|
||||||
|
|
||||||
|
xs[:] = np_array([-5.0, -6.0, -7.0, -8.0])
|
||||||
|
output_ndarray_float_2(xs)
|
||||||
|
|
||||||
|
# Test assignment with memory sharing
|
||||||
|
ys1 = np_reshape(xs, (2, 4))
|
||||||
|
ys2 = np_transpose(ys1)
|
||||||
|
ys3 = ys2[::-1, 0]
|
||||||
|
ys3[0] = -999.0
|
||||||
|
|
||||||
|
output_ndarray_float_2(xs)
|
||||||
|
output_ndarray_float_2(ys1)
|
||||||
|
output_ndarray_float_2(ys2)
|
||||||
|
output_ndarray_float_1(ys3)
|
||||||
|
|
||||||
def test_ndarray_add():
|
def test_ndarray_add():
|
||||||
x = np_identity(2)
|
x = np_identity(2)
|
||||||
y = x + np_ones([2, 2])
|
y = x + np_ones([2, 2])
|
||||||
@ -1653,6 +1685,7 @@ def run() -> int32:
|
|||||||
test_ndarray_transpose()
|
test_ndarray_transpose()
|
||||||
test_ndarray_reshape()
|
test_ndarray_reshape()
|
||||||
test_ndarray_broadcast_to()
|
test_ndarray_broadcast_to()
|
||||||
|
test_ndarray_subscript_assignment()
|
||||||
|
|
||||||
test_ndarray_add()
|
test_ndarray_add()
|
||||||
test_ndarray_add_broadcast()
|
test_ndarray_add_broadcast()
|
||||||
|
Loading…
Reference in New Issue
Block a user