[core] codegen/values/ndarray: Add more ScalarOrNDArray utils

Based on f731e604: core/ndstrides: add more ScalarOrNDArray and
NDArrayObject utils
This commit is contained in:
David Mak 2024-12-18 16:32:34 +08:00
parent 7375983e0c
commit dcde1d9c87

View File

@ -12,13 +12,16 @@ use super::{
TypedArrayLikeAdapter, TypedArrayLikeMutator, UntypedArrayLikeAccessor,
UntypedArrayLikeMutator,
};
use crate::codegen::{
irrt,
llvm_intrinsics::{call_int_umin, call_memcpy_generic_array},
stmt::gen_for_callback_incrementing,
type_aligned_alloca,
types::{ndarray::NDArrayType, structure::StructField, TupleType},
CodeGenContext, CodeGenerator,
use crate::{
codegen::{
irrt,
llvm_intrinsics::{call_int_umin, call_memcpy_generic_array},
stmt::gen_for_callback_incrementing,
type_aligned_alloca,
types::{ndarray::NDArrayType, structure::StructField, TupleType},
CodeGenContext, CodeGenerator,
},
typecheck::typedef::{Type, TypeEnum},
};
pub use broadcast::*;
pub use contiguous::*;
@ -501,22 +504,38 @@ impl<'ctx> NDArrayValue<'ctx> {
self.ndims.map(|ndims| ndims == 0)
}
/// If this ndarray is unsized, return its sole value as an [`BasicValueEnum`].
/// Otherwise, do nothing and return the ndarray itself.
// TODO: Rename to get_unsized_element
pub fn split_unsized<G: CodeGenerator + ?Sized>(
/// Returns the element present in this `ndarray` if this is unsized.
pub fn get_unsized_element<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> ScalarOrNDArray<'ctx> {
let Some(is_unsized) = self.is_unsized() else { todo!() };
) -> Option<BasicValueEnum<'ctx>> {
let Some(is_unsized) = self.is_unsized() else {
panic!("NDArrayValue::get_unsized_element can only be called on an instance with compile-time known ndims (self.ndims = Some(ndims))");
};
if is_unsized {
// NOTE: `np.size(self) == 0` here is never possible.
let zero = generator.get_size_type(ctx.ctx).const_zero();
let value = unsafe { self.data().get_unchecked(ctx, generator, &zero, None) };
ScalarOrNDArray::Scalar(value)
Some(value)
} else {
None
}
}
/// If this ndarray is unsized, return its sole value as an [`BasicValueEnum`].
/// Otherwise, do nothing and return the ndarray itself.
pub fn split_unsized<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> ScalarOrNDArray<'ctx> {
assert!(self.ndims.is_some(), "NDArrayValue::split_unsized can only be called on an instance with compile-time known ndims (self.ndims = Some(ndims))");
if let Some(unsized_elem) = self.get_unsized_element(generator, ctx) {
ScalarOrNDArray::Scalar(unsized_elem)
} else {
ScalarOrNDArray::NDArray(*self)
}
@ -978,7 +997,52 @@ pub enum ScalarOrNDArray<'ctx> {
NDArray(NDArrayValue<'ctx>),
}
impl<'ctx> TryFrom<&ScalarOrNDArray<'ctx>> for BasicValueEnum<'ctx> {
type Error = ();
fn try_from(value: &ScalarOrNDArray<'ctx>) -> Result<Self, Self::Error> {
match value {
ScalarOrNDArray::Scalar(scalar) => Ok(*scalar),
ScalarOrNDArray::NDArray(_) => Err(()),
}
}
}
impl<'ctx> TryFrom<&ScalarOrNDArray<'ctx>> for NDArrayValue<'ctx> {
type Error = ();
fn try_from(value: &ScalarOrNDArray<'ctx>) -> Result<Self, Self::Error> {
match value {
ScalarOrNDArray::Scalar(_) => Err(()),
ScalarOrNDArray::NDArray(ndarray) => Ok(*ndarray),
}
}
}
impl<'ctx> ScalarOrNDArray<'ctx> {
/// Split on `object` either into a scalar or an ndarray.
///
/// If `object` is an ndarray, [`ScalarOrNDArray::NDArray`].
///
/// For everything else, it is wrapped with [`ScalarOrNDArray::Scalar`].
pub fn from_value<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
(object_ty, object): (Type, BasicValueEnum<'ctx>),
) -> ScalarOrNDArray<'ctx> {
match &*ctx.unifier.get_ty(object_ty) {
TypeEnum::TObj { obj_id, .. }
if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() =>
{
let ndarray = NDArrayType::from_unifier_type(generator, ctx, object_ty)
.map_value(object.into_pointer_value(), None);
ScalarOrNDArray::NDArray(ndarray)
}
_ => ScalarOrNDArray::Scalar(object),
}
}
/// Get the underlying [`BasicValueEnum<'ctx>`] of this [`ScalarOrNDArray`].
#[must_use]
pub fn to_basic_value_enum(self) -> BasicValueEnum<'ctx> {
@ -987,4 +1051,33 @@ impl<'ctx> ScalarOrNDArray<'ctx> {
ScalarOrNDArray::NDArray(ndarray) => ndarray.as_base_value().into(),
}
}
/// Convert this [`ScalarOrNDArray`] to an ndarray - behaves like `np.asarray`.
///
/// - If this is an ndarray, the ndarray is returned.
/// - If this is a scalar, this function returns new ndarray created with
/// [`NDArrayType::construct_unsized`].
pub fn to_ndarray<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> NDArrayValue<'ctx> {
match self {
ScalarOrNDArray::NDArray(ndarray) => *ndarray,
ScalarOrNDArray::Scalar(scalar) => {
NDArrayType::new_unsized(generator, ctx.ctx, scalar.get_type())
.construct_unsized(generator, ctx, scalar, None)
}
}
}
/// Get the dtype of the ndarray created if this were called with
/// [`ScalarOrNDArray::to_ndarray`].
#[must_use]
pub fn get_dtype(&self) -> BasicTypeEnum<'ctx> {
match self {
ScalarOrNDArray::NDArray(ndarray) => ndarray.dtype,
ScalarOrNDArray::Scalar(scalar) => scalar.get_type(),
}
}
}