ndstrides: [9] Implement ndarray subscript assignment #519
|
@ -18,7 +18,7 @@ use crate::{
|
||||||
CodeGenContext, CodeGenerator,
|
CodeGenContext, CodeGenerator,
|
||||||
},
|
},
|
||||||
toplevel::{helper::extract_ndims, numpy::unpack_ndarray_var_tys},
|
toplevel::{helper::extract_ndims, numpy::unpack_ndarray_var_tys},
|
||||||
typecheck::typedef::Type,
|
typecheck::typedef::{Type, TypeEnum},
|
||||||
};
|
};
|
||||||
|
|
||||||
pub mod array;
|
pub mod array;
|
||||||
|
@ -483,6 +483,22 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||||
|
|
||||||
TupleObject::from_objects(generator, ctx, objects)
|
TupleObject::from_objects(generator, ctx, objects)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Create an unsized ndarray to contain `object`.
|
||||||
|
pub fn make_unsized<G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
object: AnyObject<'ctx>,
|
||||||
|
) -> NDArrayObject<'ctx> {
|
||||||
|
// We have to put the value on the stack to get a data pointer.
|
||||||
|
let data = ctx.builder.build_alloca(object.value.get_type(), "make_unsized").unwrap();
|
||||||
|
ctx.builder.build_store(data, object.value).unwrap();
|
||||||
|
let data = Ptr(Int(Byte)).pointer_cast(generator, ctx, data);
|
||||||
|
|
||||||
|
let ndarray = NDArrayObject::alloca(generator, ctx, object.ty, 0);
|
||||||
|
ndarray.instance.set(ctx, |f| f.data, data);
|
||||||
|
ndarray
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// A convenience enum for implementing functions that acts on scalars or ndarrays or both.
|
/// A convenience enum for implementing functions that acts on scalars or ndarrays or both.
|
||||||
|
@ -492,7 +508,50 @@ pub enum ScalarOrNDArray<'ctx> {
|
||||||
NDArray(NDArrayObject<'ctx>),
|
NDArray(NDArrayObject<'ctx>),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<'ctx> TryFrom<&ScalarOrNDArray<'ctx>> for AnyObject<'ctx> {
|
||||||
|
type Error = ();
|
||||||
|
|
||||||
|
fn try_from(value: &ScalarOrNDArray<'ctx>) -> Result<Self, Self::Error> {
|
||||||
|
match value {
|
||||||
|
ScalarOrNDArray::Scalar(scalar) => Ok(*scalar),
|
||||||
|
ScalarOrNDArray::NDArray(_ndarray) => Err(()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx> TryFrom<&ScalarOrNDArray<'ctx>> for NDArrayObject<'ctx> {
|
||||||
|
type Error = ();
|
||||||
|
|
||||||
|
fn try_from(value: &ScalarOrNDArray<'ctx>) -> Result<Self, Self::Error> {
|
||||||
|
match value {
|
||||||
|
ScalarOrNDArray::Scalar(_scalar) => Err(()),
|
||||||
|
ScalarOrNDArray::NDArray(ndarray) => Ok(*ndarray),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl<'ctx> ScalarOrNDArray<'ctx> {
|
impl<'ctx> ScalarOrNDArray<'ctx> {
|
||||||
|
/// Split on `object` either into a scalar or an ndarray.
|
||||||
|
///
|
||||||
|
/// If `object` is an ndarray, [`ScalarOrNDArray::NDArray`].
|
||||||
|
///
|
||||||
|
/// For everything else, it is wrapped with [`ScalarOrNDArray::Scalar`].
|
||||||
|
pub fn split_object<G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
object: AnyObject<'ctx>,
|
||||||
|
) -> ScalarOrNDArray<'ctx> {
|
||||||
|
match &*ctx.unifier.get_ty(object.ty) {
|
||||||
|
TypeEnum::TObj { obj_id, .. }
|
||||||
|
if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() =>
|
||||||
|
{
|
||||||
|
let ndarray = NDArrayObject::from_object(generator, ctx, object);
|
||||||
|
ScalarOrNDArray::NDArray(ndarray)
|
||||||
|
}
|
||||||
|
_ => ScalarOrNDArray::Scalar(object),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Get the underlying [`BasicValueEnum<'ctx>`] of this [`ScalarOrNDArray`].
|
/// Get the underlying [`BasicValueEnum<'ctx>`] of this [`ScalarOrNDArray`].
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn to_basic_value_enum(self) -> BasicValueEnum<'ctx> {
|
pub fn to_basic_value_enum(self) -> BasicValueEnum<'ctx> {
|
||||||
|
@ -501,4 +560,27 @@ impl<'ctx> ScalarOrNDArray<'ctx> {
|
||||||
ScalarOrNDArray::NDArray(ndarray) => ndarray.instance.value.as_basic_value_enum(),
|
ScalarOrNDArray::NDArray(ndarray) => ndarray.instance.value.as_basic_value_enum(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Convert this [`ScalarOrNDArray`] to an ndarray - behaves like `np.asarray`.
|
||||||
|
/// - If this is an ndarray, the ndarray is returned.
|
||||||
|
/// - If this is a scalar, this function returns new ndarray created with [`NDArrayObject::make_unsized`].
|
||||||
|
pub fn to_ndarray<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
) -> NDArrayObject<'ctx> {
|
||||||
|
match self {
|
||||||
|
ScalarOrNDArray::NDArray(ndarray) => *ndarray,
|
||||||
|
ScalarOrNDArray::Scalar(scalar) => NDArrayObject::make_unsized(generator, ctx, *scalar),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the dtype of the ndarray created if this were called with [`ScalarOrNDArray::to_ndarray`].
|
||||||
|
#[must_use]
|
||||||
|
pub fn get_dtype(&self) -> Type {
|
||||||
|
match self {
|
||||||
|
ScalarOrNDArray::NDArray(ndarray) => ndarray.dtype,
|
||||||
|
ScalarOrNDArray::Scalar(scalar) => scalar.ty,
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,6 +17,12 @@ use super::{
|
||||||
gen_in_range_check,
|
gen_in_range_check,
|
||||||
irrt::{handle_slice_indices, list_slice_assignment},
|
irrt::{handle_slice_indices, list_slice_assignment},
|
||||||
macros::codegen_unreachable,
|
macros::codegen_unreachable,
|
||||||
|
object::{
|
||||||
|
any::AnyObject,
|
||||||
|
ndarray::{
|
||||||
|
indexing::util::gen_ndarray_subscript_ndindices, NDArrayObject, ScalarOrNDArray,
|
||||||
|
},
|
||||||
|
},
|
||||||
CodeGenContext, CodeGenerator,
|
CodeGenContext, CodeGenerator,
|
||||||
};
|
};
|
||||||
use crate::{
|
use crate::{
|
||||||
|
@ -411,7 +417,47 @@ pub fn gen_setitem<'ctx, G: CodeGenerator>(
|
||||||
if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() =>
|
if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() =>
|
||||||
{
|
{
|
||||||
// Handle NDArray item assignment
|
// Handle NDArray item assignment
|
||||||
todo!("ndarray subscript assignment is not yet implemented");
|
// Process target
|
||||||
|
let target = generator
|
||||||
|
.gen_expr(ctx, target)?
|
||||||
|
.unwrap()
|
||||||
|
.to_basic_value_enum(ctx, generator, target_ty)?;
|
||||||
|
let target = AnyObject { value: target, ty: target_ty };
|
||||||
|
|
||||||
|
// Process key
|
||||||
|
let key = gen_ndarray_subscript_ndindices(generator, ctx, key)?;
|
||||||
|
|
||||||
|
// Process value
|
||||||
|
let value = value.to_basic_value_enum(ctx, generator, value_ty)?;
|
||||||
|
let value = AnyObject { value, ty: value_ty };
|
||||||
|
|
||||||
|
/*
|
||||||
|
Reference code:
|
||||||
|
```python
|
||||||
|
target = target[key]
|
||||||
|
value = np.asarray(value)
|
||||||
|
|
||||||
|
shape = np.broadcast_shape((target, value))
|
||||||
|
|
||||||
|
target = np.broadcast_to(target, shape)
|
||||||
|
value = np.broadcast_to(value, shape)
|
||||||
|
|
||||||
|
...and finally copy 1-1 from value to target.
|
||||||
|
```
|
||||||
|
*/
|
||||||
|
|
||||||
|
let target = NDArrayObject::from_object(generator, ctx, target);
|
||||||
|
let target = target.index(generator, ctx, &key);
|
||||||
|
|
||||||
|
let value =
|
||||||
|
ScalarOrNDArray::split_object(generator, ctx, value).to_ndarray(generator, ctx);
|
||||||
|
|
||||||
|
let broadcast_result = NDArrayObject::broadcast(generator, ctx, &[target, value]);
|
||||||
|
|
||||||
|
let target = broadcast_result.ndarrays[0];
|
||||||
|
let value = broadcast_result.ndarrays[1];
|
||||||
|
|
||||||
|
target.copy_data_from(generator, ctx, value);
|
||||||
}
|
}
|
||||||
_ => {
|
_ => {
|
||||||
panic!("encountered unknown target type: {}", ctx.unifier.stringify(target_ty));
|
panic!("encountered unknown target type: {}", ctx.unifier.stringify(target_ty));
|
||||||
|
|
Loading…
Reference in New Issue