forked from M-Labs/nac3
[core] codegen: Implement ScalarOrNDArray and use it in indexing
Based on 8f9d2d82: core/ndstrides: implement ndarray indexing.
This commit is contained in:
parent
438943ac6f
commit
dc91d9e35a
@ -34,7 +34,8 @@ use super::{
|
|||||||
},
|
},
|
||||||
types::{ndarray::NDArrayType, ListType},
|
types::{ndarray::NDArrayType, ListType},
|
||||||
values::{
|
values::{
|
||||||
ndarray::NDArrayValue, ArrayLikeIndexer, ArrayLikeValue, ListValue, ProxyValue, RangeValue,
|
ndarray::{NDArrayValue, RustNDIndex},
|
||||||
|
ArrayLikeIndexer, ArrayLikeValue, ListValue, ProxyValue, RangeValue,
|
||||||
TypedArrayLikeAccessor, UntypedArrayLikeAccessor,
|
TypedArrayLikeAccessor, UntypedArrayLikeAccessor,
|
||||||
},
|
},
|
||||||
CodeGenContext, CodeGenTask, CodeGenerator,
|
CodeGenContext, CodeGenTask, CodeGenerator,
|
||||||
@ -3486,19 +3487,21 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
||||||
let (ty, ndims) =
|
let Some(ndarray) = generator.gen_expr(ctx, value)? else {
|
||||||
unpack_ndarray_var_tys(&mut ctx.unifier, value.custom.unwrap());
|
|
||||||
|
|
||||||
let v = if let Some(v) = generator.gen_expr(ctx, value)? {
|
|
||||||
v.to_basic_value_enum(ctx, generator, value.custom.unwrap())?
|
|
||||||
.into_pointer_value()
|
|
||||||
} else {
|
|
||||||
return Ok(None);
|
return Ok(None);
|
||||||
};
|
};
|
||||||
let v = NDArrayType::from_unifier_type(generator, ctx, value.custom.unwrap())
|
|
||||||
.map_value(v, None);
|
|
||||||
|
|
||||||
return gen_ndarray_subscript_expr(generator, ctx, ty, ndims, v, slice);
|
let ndarray_ty = value.custom.unwrap();
|
||||||
|
let ndarray = ndarray.to_basic_value_enum(ctx, generator, ndarray_ty)?;
|
||||||
|
let ndarray = NDArrayType::from_unifier_type(generator, ctx, ndarray_ty)
|
||||||
|
.map_value(ndarray.into_pointer_value(), None);
|
||||||
|
|
||||||
|
let indices = RustNDIndex::from_subscript_expr(generator, ctx, slice)?;
|
||||||
|
let result = ndarray
|
||||||
|
.index(generator, ctx, &indices)
|
||||||
|
.split_unsized(generator, ctx)
|
||||||
|
.to_basic_value_enum();
|
||||||
|
return Ok(Some(ValueEnum::Dynamic(result)));
|
||||||
}
|
}
|
||||||
TypeEnum::TTuple { .. } => {
|
TypeEnum::TTuple { .. } => {
|
||||||
let index: u32 =
|
let index: u32 =
|
||||||
|
@ -403,6 +403,33 @@ impl<'ctx> NDArrayValue<'ctx> {
|
|||||||
assert_eq!(self.dtype, src.dtype, "self and src dtype should match");
|
assert_eq!(self.dtype, src.dtype, "self and src dtype should match");
|
||||||
irrt::ndarray::call_nac3_ndarray_copy_data(generator, ctx, src, *self);
|
irrt::ndarray::call_nac3_ndarray_copy_data(generator, ctx, src, *self);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns true if this ndarray is unsized - `ndims == 0` and only contains a scalar.
|
||||||
|
#[must_use]
|
||||||
|
pub fn is_unsized(&self) -> Option<bool> {
|
||||||
|
self.ndims.map(|ndims| ndims == 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// If this ndarray is unsized, return its sole value as an [`AnyObject`].
|
||||||
|
/// Otherwise, do nothing and return the ndarray itself.
|
||||||
|
// TODO: Rename to get_unsized_element
|
||||||
|
pub fn split_unsized<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
) -> ScalarOrNDArray<'ctx> {
|
||||||
|
let Some(is_unsized) = self.is_unsized() else { todo!() };
|
||||||
|
|
||||||
|
if is_unsized {
|
||||||
|
// NOTE: `np.size(self) == 0` here is never possible.
|
||||||
|
let zero = generator.get_size_type(ctx.ctx).const_zero();
|
||||||
|
let value = unsafe { self.data().get_unchecked(ctx, generator, &zero, None) };
|
||||||
|
|
||||||
|
ScalarOrNDArray::Scalar(value)
|
||||||
|
} else {
|
||||||
|
ScalarOrNDArray::NDArray(*self)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'ctx> ProxyValue<'ctx> for NDArrayValue<'ctx> {
|
impl<'ctx> ProxyValue<'ctx> for NDArrayValue<'ctx> {
|
||||||
@ -884,3 +911,21 @@ pub fn make_contiguous_strides(itemsize: u64, ndims: u64, shape: &[u64]) -> Vec<
|
|||||||
}
|
}
|
||||||
strides
|
strides
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// A convenience enum for implementing functions that acts on scalars or ndarrays or both.
|
||||||
|
#[derive(Clone, Copy)]
|
||||||
|
pub enum ScalarOrNDArray<'ctx> {
|
||||||
|
Scalar(BasicValueEnum<'ctx>),
|
||||||
|
NDArray(NDArrayValue<'ctx>),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx> ScalarOrNDArray<'ctx> {
|
||||||
|
/// Get the underlying [`BasicValueEnum<'ctx>`] of this [`ScalarOrNDArray`].
|
||||||
|
#[must_use]
|
||||||
|
pub fn to_basic_value_enum(self) -> BasicValueEnum<'ctx> {
|
||||||
|
match self {
|
||||||
|
ScalarOrNDArray::Scalar(scalar) => scalar,
|
||||||
|
ScalarOrNDArray::NDArray(ndarray) => ndarray.as_base_value().into(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user