forked from M-Labs/nac3
[core] codegen: Implement ScalarOrNDArray and use it in indexing
Based on 8f9d2d82: core/ndstrides: implement ndarray indexing.
This commit is contained in:
parent
438943ac6f
commit
dc91d9e35a
@ -34,7 +34,8 @@ use super::{
|
||||
},
|
||||
types::{ndarray::NDArrayType, ListType},
|
||||
values::{
|
||||
ndarray::NDArrayValue, ArrayLikeIndexer, ArrayLikeValue, ListValue, ProxyValue, RangeValue,
|
||||
ndarray::{NDArrayValue, RustNDIndex},
|
||||
ArrayLikeIndexer, ArrayLikeValue, ListValue, ProxyValue, RangeValue,
|
||||
TypedArrayLikeAccessor, UntypedArrayLikeAccessor,
|
||||
},
|
||||
CodeGenContext, CodeGenTask, CodeGenerator,
|
||||
@ -3486,19 +3487,21 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
|
||||
}
|
||||
}
|
||||
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
||||
let (ty, ndims) =
|
||||
unpack_ndarray_var_tys(&mut ctx.unifier, value.custom.unwrap());
|
||||
|
||||
let v = if let Some(v) = generator.gen_expr(ctx, value)? {
|
||||
v.to_basic_value_enum(ctx, generator, value.custom.unwrap())?
|
||||
.into_pointer_value()
|
||||
} else {
|
||||
let Some(ndarray) = generator.gen_expr(ctx, value)? else {
|
||||
return Ok(None);
|
||||
};
|
||||
let v = NDArrayType::from_unifier_type(generator, ctx, value.custom.unwrap())
|
||||
.map_value(v, None);
|
||||
|
||||
return gen_ndarray_subscript_expr(generator, ctx, ty, ndims, v, slice);
|
||||
let ndarray_ty = value.custom.unwrap();
|
||||
let ndarray = ndarray.to_basic_value_enum(ctx, generator, ndarray_ty)?;
|
||||
let ndarray = NDArrayType::from_unifier_type(generator, ctx, ndarray_ty)
|
||||
.map_value(ndarray.into_pointer_value(), None);
|
||||
|
||||
let indices = RustNDIndex::from_subscript_expr(generator, ctx, slice)?;
|
||||
let result = ndarray
|
||||
.index(generator, ctx, &indices)
|
||||
.split_unsized(generator, ctx)
|
||||
.to_basic_value_enum();
|
||||
return Ok(Some(ValueEnum::Dynamic(result)));
|
||||
}
|
||||
TypeEnum::TTuple { .. } => {
|
||||
let index: u32 =
|
||||
|
@ -403,6 +403,33 @@ impl<'ctx> NDArrayValue<'ctx> {
|
||||
assert_eq!(self.dtype, src.dtype, "self and src dtype should match");
|
||||
irrt::ndarray::call_nac3_ndarray_copy_data(generator, ctx, src, *self);
|
||||
}
|
||||
|
||||
/// Returns true if this ndarray is unsized - `ndims == 0` and only contains a scalar.
|
||||
#[must_use]
|
||||
pub fn is_unsized(&self) -> Option<bool> {
|
||||
self.ndims.map(|ndims| ndims == 0)
|
||||
}
|
||||
|
||||
/// If this ndarray is unsized, return its sole value as an [`AnyObject`].
|
||||
/// Otherwise, do nothing and return the ndarray itself.
|
||||
// TODO: Rename to get_unsized_element
|
||||
pub fn split_unsized<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
) -> ScalarOrNDArray<'ctx> {
|
||||
let Some(is_unsized) = self.is_unsized() else { todo!() };
|
||||
|
||||
if is_unsized {
|
||||
// NOTE: `np.size(self) == 0` here is never possible.
|
||||
let zero = generator.get_size_type(ctx.ctx).const_zero();
|
||||
let value = unsafe { self.data().get_unchecked(ctx, generator, &zero, None) };
|
||||
|
||||
ScalarOrNDArray::Scalar(value)
|
||||
} else {
|
||||
ScalarOrNDArray::NDArray(*self)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> ProxyValue<'ctx> for NDArrayValue<'ctx> {
|
||||
@ -884,3 +911,21 @@ pub fn make_contiguous_strides(itemsize: u64, ndims: u64, shape: &[u64]) -> Vec<
|
||||
}
|
||||
strides
|
||||
}
|
||||
|
||||
/// A convenience enum for implementing functions that acts on scalars or ndarrays or both.
|
||||
#[derive(Clone, Copy)]
|
||||
pub enum ScalarOrNDArray<'ctx> {
|
||||
Scalar(BasicValueEnum<'ctx>),
|
||||
NDArray(NDArrayValue<'ctx>),
|
||||
}
|
||||
|
||||
impl<'ctx> ScalarOrNDArray<'ctx> {
|
||||
/// Get the underlying [`BasicValueEnum<'ctx>`] of this [`ScalarOrNDArray`].
|
||||
#[must_use]
|
||||
pub fn to_basic_value_enum(self) -> BasicValueEnum<'ctx> {
|
||||
match self {
|
||||
ScalarOrNDArray::Scalar(scalar) => scalar,
|
||||
ScalarOrNDArray::NDArray(ndarray) => ndarray.as_base_value().into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user