[core] codegen/values/ndarray: Add more ScalarOrNDArray utils
Based on f731e604
: core/ndstrides: add more ScalarOrNDArray and
NDArrayObject utils
This commit is contained in:
parent
7375983e0c
commit
dcde1d9c87
@ -12,13 +12,16 @@ use super::{
|
||||
TypedArrayLikeAdapter, TypedArrayLikeMutator, UntypedArrayLikeAccessor,
|
||||
UntypedArrayLikeMutator,
|
||||
};
|
||||
use crate::codegen::{
|
||||
irrt,
|
||||
llvm_intrinsics::{call_int_umin, call_memcpy_generic_array},
|
||||
stmt::gen_for_callback_incrementing,
|
||||
type_aligned_alloca,
|
||||
types::{ndarray::NDArrayType, structure::StructField, TupleType},
|
||||
CodeGenContext, CodeGenerator,
|
||||
use crate::{
|
||||
codegen::{
|
||||
irrt,
|
||||
llvm_intrinsics::{call_int_umin, call_memcpy_generic_array},
|
||||
stmt::gen_for_callback_incrementing,
|
||||
type_aligned_alloca,
|
||||
types::{ndarray::NDArrayType, structure::StructField, TupleType},
|
||||
CodeGenContext, CodeGenerator,
|
||||
},
|
||||
typecheck::typedef::{Type, TypeEnum},
|
||||
};
|
||||
pub use broadcast::*;
|
||||
pub use contiguous::*;
|
||||
@ -501,22 +504,38 @@ impl<'ctx> NDArrayValue<'ctx> {
|
||||
self.ndims.map(|ndims| ndims == 0)
|
||||
}
|
||||
|
||||
/// If this ndarray is unsized, return its sole value as an [`BasicValueEnum`].
|
||||
/// Otherwise, do nothing and return the ndarray itself.
|
||||
// TODO: Rename to get_unsized_element
|
||||
pub fn split_unsized<G: CodeGenerator + ?Sized>(
|
||||
/// Returns the element present in this `ndarray` if this is unsized.
|
||||
pub fn get_unsized_element<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
) -> ScalarOrNDArray<'ctx> {
|
||||
let Some(is_unsized) = self.is_unsized() else { todo!() };
|
||||
) -> Option<BasicValueEnum<'ctx>> {
|
||||
let Some(is_unsized) = self.is_unsized() else {
|
||||
panic!("NDArrayValue::get_unsized_element can only be called on an instance with compile-time known ndims (self.ndims = Some(ndims))");
|
||||
};
|
||||
|
||||
if is_unsized {
|
||||
// NOTE: `np.size(self) == 0` here is never possible.
|
||||
let zero = generator.get_size_type(ctx.ctx).const_zero();
|
||||
let value = unsafe { self.data().get_unchecked(ctx, generator, &zero, None) };
|
||||
|
||||
ScalarOrNDArray::Scalar(value)
|
||||
Some(value)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// If this ndarray is unsized, return its sole value as an [`BasicValueEnum`].
|
||||
/// Otherwise, do nothing and return the ndarray itself.
|
||||
pub fn split_unsized<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
) -> ScalarOrNDArray<'ctx> {
|
||||
assert!(self.ndims.is_some(), "NDArrayValue::split_unsized can only be called on an instance with compile-time known ndims (self.ndims = Some(ndims))");
|
||||
|
||||
if let Some(unsized_elem) = self.get_unsized_element(generator, ctx) {
|
||||
ScalarOrNDArray::Scalar(unsized_elem)
|
||||
} else {
|
||||
ScalarOrNDArray::NDArray(*self)
|
||||
}
|
||||
@ -978,7 +997,52 @@ pub enum ScalarOrNDArray<'ctx> {
|
||||
NDArray(NDArrayValue<'ctx>),
|
||||
}
|
||||
|
||||
impl<'ctx> TryFrom<&ScalarOrNDArray<'ctx>> for BasicValueEnum<'ctx> {
|
||||
type Error = ();
|
||||
|
||||
fn try_from(value: &ScalarOrNDArray<'ctx>) -> Result<Self, Self::Error> {
|
||||
match value {
|
||||
ScalarOrNDArray::Scalar(scalar) => Ok(*scalar),
|
||||
ScalarOrNDArray::NDArray(_) => Err(()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> TryFrom<&ScalarOrNDArray<'ctx>> for NDArrayValue<'ctx> {
|
||||
type Error = ();
|
||||
|
||||
fn try_from(value: &ScalarOrNDArray<'ctx>) -> Result<Self, Self::Error> {
|
||||
match value {
|
||||
ScalarOrNDArray::Scalar(_) => Err(()),
|
||||
ScalarOrNDArray::NDArray(ndarray) => Ok(*ndarray),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> ScalarOrNDArray<'ctx> {
|
||||
/// Split on `object` either into a scalar or an ndarray.
|
||||
///
|
||||
/// If `object` is an ndarray, [`ScalarOrNDArray::NDArray`].
|
||||
///
|
||||
/// For everything else, it is wrapped with [`ScalarOrNDArray::Scalar`].
|
||||
pub fn from_value<G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
(object_ty, object): (Type, BasicValueEnum<'ctx>),
|
||||
) -> ScalarOrNDArray<'ctx> {
|
||||
match &*ctx.unifier.get_ty(object_ty) {
|
||||
TypeEnum::TObj { obj_id, .. }
|
||||
if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() =>
|
||||
{
|
||||
let ndarray = NDArrayType::from_unifier_type(generator, ctx, object_ty)
|
||||
.map_value(object.into_pointer_value(), None);
|
||||
ScalarOrNDArray::NDArray(ndarray)
|
||||
}
|
||||
|
||||
_ => ScalarOrNDArray::Scalar(object),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the underlying [`BasicValueEnum<'ctx>`] of this [`ScalarOrNDArray`].
|
||||
#[must_use]
|
||||
pub fn to_basic_value_enum(self) -> BasicValueEnum<'ctx> {
|
||||
@ -987,4 +1051,33 @@ impl<'ctx> ScalarOrNDArray<'ctx> {
|
||||
ScalarOrNDArray::NDArray(ndarray) => ndarray.as_base_value().into(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert this [`ScalarOrNDArray`] to an ndarray - behaves like `np.asarray`.
|
||||
///
|
||||
/// - If this is an ndarray, the ndarray is returned.
|
||||
/// - If this is a scalar, this function returns new ndarray created with
|
||||
/// [`NDArrayType::construct_unsized`].
|
||||
pub fn to_ndarray<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
) -> NDArrayValue<'ctx> {
|
||||
match self {
|
||||
ScalarOrNDArray::NDArray(ndarray) => *ndarray,
|
||||
ScalarOrNDArray::Scalar(scalar) => {
|
||||
NDArrayType::new_unsized(generator, ctx.ctx, scalar.get_type())
|
||||
.construct_unsized(generator, ctx, scalar, None)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the dtype of the ndarray created if this were called with
|
||||
/// [`ScalarOrNDArray::to_ndarray`].
|
||||
#[must_use]
|
||||
pub fn get_dtype(&self) -> BasicTypeEnum<'ctx> {
|
||||
match self {
|
||||
ScalarOrNDArray::NDArray(ndarray) => ndarray.dtype,
|
||||
ScalarOrNDArray::Scalar(scalar) => scalar.get_type(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user