forked from M-Labs/nac3
[core] codegen/values/ndarray: Add more ScalarOrNDArray utils
Based on f731e604
: core/ndstrides: add more ScalarOrNDArray and
NDArrayObject utils
This commit is contained in:
parent
7375983e0c
commit
dcde1d9c87
@ -12,13 +12,16 @@ use super::{
|
|||||||
TypedArrayLikeAdapter, TypedArrayLikeMutator, UntypedArrayLikeAccessor,
|
TypedArrayLikeAdapter, TypedArrayLikeMutator, UntypedArrayLikeAccessor,
|
||||||
UntypedArrayLikeMutator,
|
UntypedArrayLikeMutator,
|
||||||
};
|
};
|
||||||
use crate::codegen::{
|
use crate::{
|
||||||
|
codegen::{
|
||||||
irrt,
|
irrt,
|
||||||
llvm_intrinsics::{call_int_umin, call_memcpy_generic_array},
|
llvm_intrinsics::{call_int_umin, call_memcpy_generic_array},
|
||||||
stmt::gen_for_callback_incrementing,
|
stmt::gen_for_callback_incrementing,
|
||||||
type_aligned_alloca,
|
type_aligned_alloca,
|
||||||
types::{ndarray::NDArrayType, structure::StructField, TupleType},
|
types::{ndarray::NDArrayType, structure::StructField, TupleType},
|
||||||
CodeGenContext, CodeGenerator,
|
CodeGenContext, CodeGenerator,
|
||||||
|
},
|
||||||
|
typecheck::typedef::{Type, TypeEnum},
|
||||||
};
|
};
|
||||||
pub use broadcast::*;
|
pub use broadcast::*;
|
||||||
pub use contiguous::*;
|
pub use contiguous::*;
|
||||||
@ -501,22 +504,38 @@ impl<'ctx> NDArrayValue<'ctx> {
|
|||||||
self.ndims.map(|ndims| ndims == 0)
|
self.ndims.map(|ndims| ndims == 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// If this ndarray is unsized, return its sole value as an [`BasicValueEnum`].
|
/// Returns the element present in this `ndarray` if this is unsized.
|
||||||
/// Otherwise, do nothing and return the ndarray itself.
|
pub fn get_unsized_element<G: CodeGenerator + ?Sized>(
|
||||||
// TODO: Rename to get_unsized_element
|
|
||||||
pub fn split_unsized<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
&self,
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
) -> ScalarOrNDArray<'ctx> {
|
) -> Option<BasicValueEnum<'ctx>> {
|
||||||
let Some(is_unsized) = self.is_unsized() else { todo!() };
|
let Some(is_unsized) = self.is_unsized() else {
|
||||||
|
panic!("NDArrayValue::get_unsized_element can only be called on an instance with compile-time known ndims (self.ndims = Some(ndims))");
|
||||||
|
};
|
||||||
|
|
||||||
if is_unsized {
|
if is_unsized {
|
||||||
// NOTE: `np.size(self) == 0` here is never possible.
|
// NOTE: `np.size(self) == 0` here is never possible.
|
||||||
let zero = generator.get_size_type(ctx.ctx).const_zero();
|
let zero = generator.get_size_type(ctx.ctx).const_zero();
|
||||||
let value = unsafe { self.data().get_unchecked(ctx, generator, &zero, None) };
|
let value = unsafe { self.data().get_unchecked(ctx, generator, &zero, None) };
|
||||||
|
|
||||||
ScalarOrNDArray::Scalar(value)
|
Some(value)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// If this ndarray is unsized, return its sole value as an [`BasicValueEnum`].
|
||||||
|
/// Otherwise, do nothing and return the ndarray itself.
|
||||||
|
pub fn split_unsized<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
) -> ScalarOrNDArray<'ctx> {
|
||||||
|
assert!(self.ndims.is_some(), "NDArrayValue::split_unsized can only be called on an instance with compile-time known ndims (self.ndims = Some(ndims))");
|
||||||
|
|
||||||
|
if let Some(unsized_elem) = self.get_unsized_element(generator, ctx) {
|
||||||
|
ScalarOrNDArray::Scalar(unsized_elem)
|
||||||
} else {
|
} else {
|
||||||
ScalarOrNDArray::NDArray(*self)
|
ScalarOrNDArray::NDArray(*self)
|
||||||
}
|
}
|
||||||
@ -978,7 +997,52 @@ pub enum ScalarOrNDArray<'ctx> {
|
|||||||
NDArray(NDArrayValue<'ctx>),
|
NDArray(NDArrayValue<'ctx>),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<'ctx> TryFrom<&ScalarOrNDArray<'ctx>> for BasicValueEnum<'ctx> {
|
||||||
|
type Error = ();
|
||||||
|
|
||||||
|
fn try_from(value: &ScalarOrNDArray<'ctx>) -> Result<Self, Self::Error> {
|
||||||
|
match value {
|
||||||
|
ScalarOrNDArray::Scalar(scalar) => Ok(*scalar),
|
||||||
|
ScalarOrNDArray::NDArray(_) => Err(()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx> TryFrom<&ScalarOrNDArray<'ctx>> for NDArrayValue<'ctx> {
|
||||||
|
type Error = ();
|
||||||
|
|
||||||
|
fn try_from(value: &ScalarOrNDArray<'ctx>) -> Result<Self, Self::Error> {
|
||||||
|
match value {
|
||||||
|
ScalarOrNDArray::Scalar(_) => Err(()),
|
||||||
|
ScalarOrNDArray::NDArray(ndarray) => Ok(*ndarray),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl<'ctx> ScalarOrNDArray<'ctx> {
|
impl<'ctx> ScalarOrNDArray<'ctx> {
|
||||||
|
/// Split on `object` either into a scalar or an ndarray.
|
||||||
|
///
|
||||||
|
/// If `object` is an ndarray, [`ScalarOrNDArray::NDArray`].
|
||||||
|
///
|
||||||
|
/// For everything else, it is wrapped with [`ScalarOrNDArray::Scalar`].
|
||||||
|
pub fn from_value<G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
(object_ty, object): (Type, BasicValueEnum<'ctx>),
|
||||||
|
) -> ScalarOrNDArray<'ctx> {
|
||||||
|
match &*ctx.unifier.get_ty(object_ty) {
|
||||||
|
TypeEnum::TObj { obj_id, .. }
|
||||||
|
if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() =>
|
||||||
|
{
|
||||||
|
let ndarray = NDArrayType::from_unifier_type(generator, ctx, object_ty)
|
||||||
|
.map_value(object.into_pointer_value(), None);
|
||||||
|
ScalarOrNDArray::NDArray(ndarray)
|
||||||
|
}
|
||||||
|
|
||||||
|
_ => ScalarOrNDArray::Scalar(object),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Get the underlying [`BasicValueEnum<'ctx>`] of this [`ScalarOrNDArray`].
|
/// Get the underlying [`BasicValueEnum<'ctx>`] of this [`ScalarOrNDArray`].
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn to_basic_value_enum(self) -> BasicValueEnum<'ctx> {
|
pub fn to_basic_value_enum(self) -> BasicValueEnum<'ctx> {
|
||||||
@ -987,4 +1051,33 @@ impl<'ctx> ScalarOrNDArray<'ctx> {
|
|||||||
ScalarOrNDArray::NDArray(ndarray) => ndarray.as_base_value().into(),
|
ScalarOrNDArray::NDArray(ndarray) => ndarray.as_base_value().into(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Convert this [`ScalarOrNDArray`] to an ndarray - behaves like `np.asarray`.
|
||||||
|
///
|
||||||
|
/// - If this is an ndarray, the ndarray is returned.
|
||||||
|
/// - If this is a scalar, this function returns new ndarray created with
|
||||||
|
/// [`NDArrayType::construct_unsized`].
|
||||||
|
pub fn to_ndarray<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
) -> NDArrayValue<'ctx> {
|
||||||
|
match self {
|
||||||
|
ScalarOrNDArray::NDArray(ndarray) => *ndarray,
|
||||||
|
ScalarOrNDArray::Scalar(scalar) => {
|
||||||
|
NDArrayType::new_unsized(generator, ctx.ctx, scalar.get_type())
|
||||||
|
.construct_unsized(generator, ctx, scalar, None)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the dtype of the ndarray created if this were called with
|
||||||
|
/// [`ScalarOrNDArray::to_ndarray`].
|
||||||
|
#[must_use]
|
||||||
|
pub fn get_dtype(&self) -> BasicTypeEnum<'ctx> {
|
||||||
|
match self {
|
||||||
|
ScalarOrNDArray::NDArray(ndarray) => ndarray.dtype,
|
||||||
|
ScalarOrNDArray::Scalar(scalar) => scalar.get_type(),
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user