[core] codegen/values/ndarray: Add more ScalarOrNDArray utils

Based on f731e604: core/ndstrides: add more ScalarOrNDArray and
NDArrayObject utils
This commit is contained in:
David Mak 2024-12-18 16:32:34 +08:00
parent d1bf5085a6
commit 8d8f9e9b2a

View File

@ -12,13 +12,16 @@ use super::{
TypedArrayLikeAdapter, TypedArrayLikeMutator, UntypedArrayLikeAccessor,
UntypedArrayLikeMutator,
};
use crate::codegen::{
irrt,
llvm_intrinsics::{call_int_umin, call_memcpy_generic_array},
stmt::gen_for_callback_incrementing,
type_aligned_alloca,
types::{ndarray::NDArrayType, structure::StructField, TupleType},
CodeGenContext, CodeGenerator,
use crate::{
codegen::{
irrt,
llvm_intrinsics::{call_int_umin, call_memcpy_generic_array},
stmt::gen_for_callback_incrementing,
type_aligned_alloca,
types::{ndarray::NDArrayType, structure::StructField, TupleType},
CodeGenContext, CodeGenerator,
},
typecheck::typedef::{Type, TypeEnum},
};
pub use broadcast::*;
pub use contiguous::*;
@ -976,7 +979,52 @@ pub enum ScalarOrNDArray<'ctx> {
NDArray(NDArrayValue<'ctx>),
}
impl<'ctx> TryFrom<&ScalarOrNDArray<'ctx>> for BasicValueEnum<'ctx> {
type Error = ();
fn try_from(value: &ScalarOrNDArray<'ctx>) -> Result<Self, Self::Error> {
match value {
ScalarOrNDArray::Scalar(scalar) => Ok(*scalar),
ScalarOrNDArray::NDArray(_) => Err(()),
}
}
}
impl<'ctx> TryFrom<&ScalarOrNDArray<'ctx>> for NDArrayValue<'ctx> {
type Error = ();
fn try_from(value: &ScalarOrNDArray<'ctx>) -> Result<Self, Self::Error> {
match value {
ScalarOrNDArray::Scalar(_) => Err(()),
ScalarOrNDArray::NDArray(ndarray) => Ok(*ndarray),
}
}
}
impl<'ctx> ScalarOrNDArray<'ctx> {
/// Split on `object` either into a scalar or an ndarray.
///
/// If `object` is an ndarray, [`ScalarOrNDArray::NDArray`].
///
/// For everything else, it is wrapped with [`ScalarOrNDArray::Scalar`].
pub fn from_value<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
(object_ty, object): (Type, BasicValueEnum<'ctx>),
) -> ScalarOrNDArray<'ctx> {
match &*ctx.unifier.get_ty(object_ty) {
TypeEnum::TObj { obj_id, .. }
if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() =>
{
let ndarray = NDArrayType::from_unifier_type(generator, ctx, object_ty)
.map_value(object.into_pointer_value(), None);
ScalarOrNDArray::NDArray(ndarray)
}
_ => ScalarOrNDArray::Scalar(object),
}
}
/// Get the underlying [`BasicValueEnum<'ctx>`] of this [`ScalarOrNDArray`].
#[must_use]
pub fn to_basic_value_enum(self) -> BasicValueEnum<'ctx> {
@ -985,4 +1033,31 @@ impl<'ctx> ScalarOrNDArray<'ctx> {
ScalarOrNDArray::NDArray(ndarray) => ndarray.as_base_value().into(),
}
}
/// Convert this [`ScalarOrNDArray`] to an ndarray - behaves like `np.asarray`.
///
/// - If this is an ndarray, the ndarray is returned.
/// - If this is a scalar, this function returns new ndarray created with [`NDArrayObject::make_unsized`].
pub fn to_ndarray<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> NDArrayValue<'ctx> {
match self {
ScalarOrNDArray::NDArray(ndarray) => *ndarray,
ScalarOrNDArray::Scalar(scalar) => {
NDArrayType::new_unsized(generator, ctx.ctx, scalar.get_type())
.construct_unsized(generator, ctx, scalar, None)
}
}
}
/// Get the dtype of the ndarray created if this were called with [`ScalarOrNDArray::to_ndarray`].
#[must_use]
pub fn get_dtype(&self) -> BasicTypeEnum<'ctx> {
match self {
ScalarOrNDArray::NDArray(ndarray) => ndarray.dtype,
ScalarOrNDArray::Scalar(scalar) => scalar.get_type(),
}
}
}