[core] codegen: Implement ScalarOrNDArray and use it in indexing

Based on 8f9d2d82: core/ndstrides: implement ndarray indexing.
This commit is contained in:
David Mak 2024-12-10 16:43:57 +08:00
parent 438943ac6f
commit dc91d9e35a
2 changed files with 59 additions and 11 deletions

View File

@ -34,7 +34,8 @@ use super::{
},
types::{ndarray::NDArrayType, ListType},
values::{
ndarray::NDArrayValue, ArrayLikeIndexer, ArrayLikeValue, ListValue, ProxyValue, RangeValue,
ndarray::{NDArrayValue, RustNDIndex},
ArrayLikeIndexer, ArrayLikeValue, ListValue, ProxyValue, RangeValue,
TypedArrayLikeAccessor, UntypedArrayLikeAccessor,
},
CodeGenContext, CodeGenTask, CodeGenerator,
@ -3486,19 +3487,21 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
}
}
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
let (ty, ndims) =
unpack_ndarray_var_tys(&mut ctx.unifier, value.custom.unwrap());
let v = if let Some(v) = generator.gen_expr(ctx, value)? {
v.to_basic_value_enum(ctx, generator, value.custom.unwrap())?
.into_pointer_value()
} else {
let Some(ndarray) = generator.gen_expr(ctx, value)? else {
return Ok(None);
};
let v = NDArrayType::from_unifier_type(generator, ctx, value.custom.unwrap())
.map_value(v, None);
return gen_ndarray_subscript_expr(generator, ctx, ty, ndims, v, slice);
let ndarray_ty = value.custom.unwrap();
let ndarray = ndarray.to_basic_value_enum(ctx, generator, ndarray_ty)?;
let ndarray = NDArrayType::from_unifier_type(generator, ctx, ndarray_ty)
.map_value(ndarray.into_pointer_value(), None);
let indices = RustNDIndex::from_subscript_expr(generator, ctx, slice)?;
let result = ndarray
.index(generator, ctx, &indices)
.split_unsized(generator, ctx)
.to_basic_value_enum();
return Ok(Some(ValueEnum::Dynamic(result)));
}
TypeEnum::TTuple { .. } => {
let index: u32 =

View File

@ -403,6 +403,33 @@ impl<'ctx> NDArrayValue<'ctx> {
assert_eq!(self.dtype, src.dtype, "self and src dtype should match");
irrt::ndarray::call_nac3_ndarray_copy_data(generator, ctx, src, *self);
}
/// Returns true if this ndarray is unsized - `ndims == 0` and only contains a scalar.
#[must_use]
pub fn is_unsized(&self) -> Option<bool> {
self.ndims.map(|ndims| ndims == 0)
}
/// If this ndarray is unsized, return its sole value as an [`AnyObject`].
/// Otherwise, do nothing and return the ndarray itself.
// TODO: Rename to get_unsized_element
pub fn split_unsized<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> ScalarOrNDArray<'ctx> {
let Some(is_unsized) = self.is_unsized() else { todo!() };
if is_unsized {
// NOTE: `np.size(self) == 0` here is never possible.
let zero = generator.get_size_type(ctx.ctx).const_zero();
let value = unsafe { self.data().get_unchecked(ctx, generator, &zero, None) };
ScalarOrNDArray::Scalar(value)
} else {
ScalarOrNDArray::NDArray(*self)
}
}
}
impl<'ctx> ProxyValue<'ctx> for NDArrayValue<'ctx> {
@ -884,3 +911,21 @@ pub fn make_contiguous_strides(itemsize: u64, ndims: u64, shape: &[u64]) -> Vec<
}
strides
}
/// A convenience enum for implementing functions that acts on scalars or ndarrays or both.
#[derive(Clone, Copy)]
pub enum ScalarOrNDArray<'ctx> {
Scalar(BasicValueEnum<'ctx>),
NDArray(NDArrayValue<'ctx>),
}
impl<'ctx> ScalarOrNDArray<'ctx> {
/// Get the underlying [`BasicValueEnum<'ctx>`] of this [`ScalarOrNDArray`].
#[must_use]
pub fn to_basic_value_enum(self) -> BasicValueEnum<'ctx> {
match self {
ScalarOrNDArray::Scalar(scalar) => scalar,
ScalarOrNDArray::NDArray(ndarray) => ndarray.as_base_value().into(),
}
}
}