forked from M-Labs/nac3
core/ndstrides: add NDArrayOut, broadcast_map and map
This commit is contained in:
parent
adca310424
commit
3efae534f7
|
@ -10,7 +10,7 @@ use super::*;
|
||||||
|
|
||||||
/// A [`Model`] of any [`BasicTypeEnum`].
|
/// A [`Model`] of any [`BasicTypeEnum`].
|
||||||
///
|
///
|
||||||
/// Use this when you don't need/cannot have any static types to escape from the [`Model`] abstraction.
|
/// Use this when you cannot know the type beforehand or cannot be abstracted with [`Model`].
|
||||||
#[derive(Debug, Clone, Copy)]
|
#[derive(Debug, Clone, Copy)]
|
||||||
pub struct Any<'ctx>(pub BasicTypeEnum<'ctx>);
|
pub struct Any<'ctx>(pub BasicTypeEnum<'ctx>);
|
||||||
|
|
||||||
|
|
|
@ -70,54 +70,3 @@ impl<'ctx, Element: Model<'ctx>> Instance<'ctx, Ptr<Array<Element>>> {
|
||||||
Ptr(self.model.0.item).check_value(generator, ctx.ctx, ptr).unwrap()
|
Ptr(self.model.0.item).check_value(generator, ctx.ctx, ptr).unwrap()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Like [`ArrayModel`] but length is strongly-typed.
|
|
||||||
#[derive(Debug, Clone, Copy, Default)]
|
|
||||||
pub struct NArrayModel<const LEN: u32, Element>(pub Element);
|
|
||||||
pub type NArray<'ctx, const LEN: u32, Element> = Instance<'ctx, NArrayModel<LEN, Element>>;
|
|
||||||
|
|
||||||
impl<'ctx, const LEN: u32, Element: Model<'ctx>> NArrayModel<LEN, Element> {
|
|
||||||
/// Forget the `LEN` constant generic and get an [`ArrayModel`] with the same length.
|
|
||||||
pub fn forget_len(&self) -> Array<Element> {
|
|
||||||
Array { item: self.0, len: LEN }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'ctx, const LEN: u32, Element: Model<'ctx>> Model<'ctx> for NArrayModel<LEN, Element> {
|
|
||||||
type Value = ArrayValue<'ctx>;
|
|
||||||
type Type = ArrayType<'ctx>;
|
|
||||||
|
|
||||||
fn get_type<G: CodeGenerator + ?Sized>(&self, generator: &G, ctx: &'ctx Context) -> Self::Type {
|
|
||||||
// Convenient implementation
|
|
||||||
self.forget_len().get_type(generator, ctx)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn check_type<T: BasicType<'ctx>, G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &'ctx Context,
|
|
||||||
ty: T,
|
|
||||||
) -> Result<(), ModelError> {
|
|
||||||
// Convenient implementation
|
|
||||||
self.forget_len().check_type(generator, ctx, ty)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'ctx, const LEN: u32, Element: Model<'ctx>> Instance<'ctx, Ptr<NArrayModel<LEN, Element>>> {
|
|
||||||
/// Get the pointer to the `i`-th (0-based) array element.
|
|
||||||
pub fn at_const<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
i: u32,
|
|
||||||
name: &str,
|
|
||||||
) -> Instance<'ctx, Ptr<Element>> {
|
|
||||||
assert!(i < LEN);
|
|
||||||
|
|
||||||
let zero = ctx.ctx.i32_type().const_zero();
|
|
||||||
let i = ctx.ctx.i32_type().const_int(u64::from(i), false);
|
|
||||||
let ptr = unsafe { ctx.builder.build_in_bounds_gep(self.value, &[zero, i], name).unwrap() };
|
|
||||||
|
|
||||||
Ptr(self.model.0 .0).check_value(generator, ctx.ctx, ptr).unwrap()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
use std::fmt;
|
use std::fmt;
|
||||||
|
|
||||||
use inkwell::{context::Context, types::*, values::*};
|
use inkwell::{context::Context, types::*, values::*};
|
||||||
|
use itertools::Itertools;
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::codegen::{CodeGenContext, CodeGenerator};
|
use crate::codegen::{CodeGenContext, CodeGenerator};
|
||||||
|
@ -110,6 +111,35 @@ pub trait Model<'ctx>: fmt::Debug + Clone + Copy {
|
||||||
let p = generator.gen_array_var_alloc(ctx, ty, len, name)?;
|
let p = generator.gen_array_var_alloc(ctx, ty, len, name)?;
|
||||||
Ok(Ptr(*self).believe_value(PointerValue::from(p)))
|
Ok(Ptr(*self).believe_value(PointerValue::from(p)))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn const_array<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &'ctx Context,
|
||||||
|
values: &[Instance<'ctx, Self>],
|
||||||
|
) -> Instance<'ctx, Array<Self>> {
|
||||||
|
macro_rules! make {
|
||||||
|
($t:expr, $into_value:expr) => {
|
||||||
|
$t.const_array(
|
||||||
|
&values
|
||||||
|
.iter()
|
||||||
|
.map(|x| $into_value(x.value.as_basic_value_enum()))
|
||||||
|
.collect_vec(),
|
||||||
|
)
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
let value = match self.get_type(generator, ctx).as_basic_type_enum() {
|
||||||
|
BasicTypeEnum::ArrayType(t) => make!(t, BasicValueEnum::into_array_value),
|
||||||
|
BasicTypeEnum::IntType(t) => make!(t, BasicValueEnum::into_int_value),
|
||||||
|
BasicTypeEnum::FloatType(t) => make!(t, BasicValueEnum::into_float_value),
|
||||||
|
BasicTypeEnum::PointerType(t) => make!(t, BasicValueEnum::into_pointer_value),
|
||||||
|
BasicTypeEnum::StructType(t) => make!(t, BasicValueEnum::into_struct_value),
|
||||||
|
BasicTypeEnum::VectorType(t) => make!(t, BasicValueEnum::into_vector_value),
|
||||||
|
};
|
||||||
|
|
||||||
|
Array { len: values.len() as u32, item: *self }.check_value(generator, ctx, value).unwrap()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy)]
|
#[derive(Debug, Clone, Copy)]
|
||||||
|
|
|
@ -134,7 +134,7 @@ impl<'ctx, S: StructKind<'ctx>> Struct<S> {
|
||||||
/// Create a constant struct value.
|
/// Create a constant struct value.
|
||||||
///
|
///
|
||||||
/// This function also validates `fields` and panic when there is something wrong.
|
/// This function also validates `fields` and panic when there is something wrong.
|
||||||
fn const_struct<G: CodeGenerator + ?Sized>(
|
pub fn const_struct<G: CodeGenerator + ?Sized>(
|
||||||
&self,
|
&self,
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
ctx: &'ctx Context,
|
ctx: &'ctx Context,
|
||||||
|
|
|
@ -0,0 +1,220 @@
|
||||||
|
use inkwell::values::BasicValueEnum;
|
||||||
|
use itertools::Itertools;
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
codegen::{
|
||||||
|
object::ndarray::{AnyObject, NDArrayObject},
|
||||||
|
stmt::gen_for_callback,
|
||||||
|
CodeGenContext, CodeGenerator,
|
||||||
|
},
|
||||||
|
typecheck::typedef::Type,
|
||||||
|
};
|
||||||
|
|
||||||
|
use super::{nditer::NDIterHandle, NDArrayOut, ScalarOrNDArray};
|
||||||
|
|
||||||
|
impl<'ctx> NDArrayObject<'ctx> {
|
||||||
|
/// Generate LLVM IR to broadcast `ndarray`s together, and starmap through them with `mapping` elementwise.
|
||||||
|
///
|
||||||
|
/// `mapping` is an LLVM IR generator. The input of `mapping` is the list of elements when iterating through
|
||||||
|
/// the input `ndarrays` after broadcasting. The output of `mapping` is the result of the elementwise operation.
|
||||||
|
///
|
||||||
|
/// `out` specifies whether the result should be a new ndarray or to be written an existing ndarray.
|
||||||
|
pub fn broadcast_starmap<'a, G, MappingFn>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||||
|
ndarrays: &[Self],
|
||||||
|
out: NDArrayOut<'ctx>,
|
||||||
|
mapping: MappingFn,
|
||||||
|
) -> Result<Self, String>
|
||||||
|
where
|
||||||
|
G: CodeGenerator + ?Sized,
|
||||||
|
MappingFn: FnOnce(
|
||||||
|
&mut G,
|
||||||
|
&mut CodeGenContext<'ctx, 'a>,
|
||||||
|
&[BasicValueEnum<'ctx>],
|
||||||
|
) -> Result<BasicValueEnum<'ctx>, String>,
|
||||||
|
{
|
||||||
|
// Broadcast inputs
|
||||||
|
let broadcast_result = NDArrayObject::broadcast(generator, ctx, ndarrays);
|
||||||
|
|
||||||
|
let out_ndarray = match out {
|
||||||
|
NDArrayOut::NewNDArray { dtype } => {
|
||||||
|
// Create a new ndarray based on the broadcast shape.
|
||||||
|
let result_ndarray =
|
||||||
|
NDArrayObject::alloca(generator, ctx, dtype, broadcast_result.ndims);
|
||||||
|
result_ndarray.copy_shape_from_array(generator, ctx, broadcast_result.shape);
|
||||||
|
result_ndarray.create_data(generator, ctx);
|
||||||
|
result_ndarray
|
||||||
|
}
|
||||||
|
NDArrayOut::WriteToNDArray { ndarray: result_ndarray } => {
|
||||||
|
// Use an existing ndarray.
|
||||||
|
|
||||||
|
// Check that its shape is compatible with the broadcast shape.
|
||||||
|
result_ndarray.assert_can_be_written_by_out(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
broadcast_result.ndims,
|
||||||
|
broadcast_result.shape,
|
||||||
|
);
|
||||||
|
result_ndarray
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Map element-wise and store results into `mapped_ndarray`.
|
||||||
|
let nditer = NDIterHandle::new(generator, ctx, out_ndarray);
|
||||||
|
gen_for_callback(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
Some("broadcast_starmap"),
|
||||||
|
|generator, ctx| {
|
||||||
|
// Create NDIters for all broadcasted input ndarrays.
|
||||||
|
let other_nditers = broadcast_result
|
||||||
|
.ndarrays
|
||||||
|
.iter()
|
||||||
|
.map(|ndarray| NDIterHandle::new(generator, ctx, *ndarray))
|
||||||
|
.collect_vec();
|
||||||
|
Ok((nditer, other_nditers))
|
||||||
|
},
|
||||||
|
|generator, ctx, (out_nditer, _in_nditers)| {
|
||||||
|
// We can simply use `out_nditer`'s `has_next()`.
|
||||||
|
// `in_nditers`' `has_next()`s should return the same value.
|
||||||
|
Ok(out_nditer.has_next(generator, ctx).value)
|
||||||
|
},
|
||||||
|
|generator, ctx, _hooks, (out_nditer, in_nditers)| {
|
||||||
|
// Get all the scalars from the broadcasted input ndarrays, pass them to `mapping`,
|
||||||
|
// and write to `out_ndarray`.
|
||||||
|
|
||||||
|
let in_scalars = in_nditers
|
||||||
|
.iter()
|
||||||
|
.map(|nditer| nditer.get_scalar(generator, ctx).value)
|
||||||
|
.collect_vec();
|
||||||
|
|
||||||
|
let result = mapping(generator, ctx, &in_scalars)?;
|
||||||
|
|
||||||
|
let p = out_nditer.get_pointer(generator, ctx);
|
||||||
|
ctx.builder.build_store(p, result).unwrap();
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
},
|
||||||
|
|generator, ctx, (out_nditer, in_nditers)| {
|
||||||
|
// Advance all iterators
|
||||||
|
out_nditer.next(generator, ctx);
|
||||||
|
in_nditers.iter().for_each(|nditer| nditer.next(generator, ctx));
|
||||||
|
Ok(())
|
||||||
|
},
|
||||||
|
)?;
|
||||||
|
|
||||||
|
Ok(out_ndarray)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Map through this ndarray with an elementwise function.
|
||||||
|
pub fn map<'a, G, Mapping>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||||
|
out: NDArrayOut<'ctx>,
|
||||||
|
mapping: Mapping,
|
||||||
|
) -> Result<Self, String>
|
||||||
|
where
|
||||||
|
G: CodeGenerator + ?Sized,
|
||||||
|
Mapping: FnOnce(
|
||||||
|
&mut G,
|
||||||
|
&mut CodeGenContext<'ctx, 'a>,
|
||||||
|
BasicValueEnum<'ctx>,
|
||||||
|
) -> Result<BasicValueEnum<'ctx>, String>,
|
||||||
|
{
|
||||||
|
NDArrayObject::broadcast_starmap(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
&[*self],
|
||||||
|
out,
|
||||||
|
|generator, ctx, scalars| mapping(generator, ctx, scalars[0]),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx> ScalarOrNDArray<'ctx> {
|
||||||
|
/// Starmap through a list of inputs using `mapping`, where an input could be an ndarray, a scalar.
|
||||||
|
///
|
||||||
|
/// This function is very helpful when implementing NumPy functions that takes on either scalars or ndarrays or a mix of them
|
||||||
|
/// as their inputs and produces either an ndarray with broadcast, or a scalar if all its inputs are all scalars.
|
||||||
|
///
|
||||||
|
/// For example ,this function can be used to implement `np.add`, which has the following behaviors:
|
||||||
|
/// - `np.add(3, 4) = 7` # (scalar, scalar) -> scalar
|
||||||
|
/// - `np.add(3, np.array([4, 5, 6]))` # (scalar, ndarray) -> ndarray; the first `scalar` is converted into an ndarray and broadcasted.
|
||||||
|
/// - `np.add(np.array([[1], [2], [3]]), np.array([[4, 5, 6]]))` # (ndarray, ndarray) -> ndarray; there is broadcasting.
|
||||||
|
///
|
||||||
|
/// ## Details:
|
||||||
|
///
|
||||||
|
/// If `inputs` are all [`ScalarOrNDArray::Scalar`], the output will be a [`ScalarOrNDArray::Scalar`] with type `ret_dtype`.
|
||||||
|
///
|
||||||
|
/// Otherwise (if there are any [`ScalarOrNDArray::NDArray`] in `inputs`), all inputs will be 'as-ndarray'-ed into ndarrays,
|
||||||
|
/// then all inputs (now all ndarrays) will be passed to [`NDArrayObject::broadcasting_starmap`] and **create** a new ndarray
|
||||||
|
/// with dtype `ret_dtype`.
|
||||||
|
pub fn broadcasting_starmap<'a, G, MappingFn>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||||
|
inputs: &[ScalarOrNDArray<'ctx>],
|
||||||
|
ret_dtype: Type,
|
||||||
|
mapping: MappingFn,
|
||||||
|
) -> Result<ScalarOrNDArray<'ctx>, String>
|
||||||
|
where
|
||||||
|
G: CodeGenerator + ?Sized,
|
||||||
|
MappingFn: FnOnce(
|
||||||
|
&mut G,
|
||||||
|
&mut CodeGenContext<'ctx, 'a>,
|
||||||
|
&[BasicValueEnum<'ctx>],
|
||||||
|
) -> Result<BasicValueEnum<'ctx>, String>,
|
||||||
|
{
|
||||||
|
// Check if all inputs are Scalars
|
||||||
|
let all_scalars: Option<Vec<_>> = inputs.iter().map(AnyObject::try_from).try_collect().ok();
|
||||||
|
|
||||||
|
if let Some(scalars) = all_scalars {
|
||||||
|
let scalars = scalars.iter().map(|scalar| scalar.value).collect_vec();
|
||||||
|
let value = mapping(generator, ctx, &scalars)?;
|
||||||
|
|
||||||
|
Ok(ScalarOrNDArray::Scalar(AnyObject { ty: ret_dtype, value }))
|
||||||
|
} else {
|
||||||
|
// Promote all input to ndarrays and map through them.
|
||||||
|
let inputs = inputs.iter().map(|input| input.to_ndarray(generator, ctx)).collect_vec();
|
||||||
|
let ndarray = NDArrayObject::broadcast_starmap(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
&inputs,
|
||||||
|
NDArrayOut::NewNDArray { dtype: ret_dtype },
|
||||||
|
mapping,
|
||||||
|
)?;
|
||||||
|
Ok(ScalarOrNDArray::NDArray(ndarray))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Map through this [`ScalarOrNDArray`] with an elementwise function.
|
||||||
|
///
|
||||||
|
/// If this is a scalar, `mapping` will directly act on the scalar. This function will return a [`ScalarOrNDArray::Scalar`] of that result.
|
||||||
|
///
|
||||||
|
/// If this is an ndarray, `mapping` will be applied to the elements of the ndarray. A new ndarray of the results will be created and
|
||||||
|
/// returned as a [`ScalarOrNDArray::NDArray`].
|
||||||
|
pub fn map<'a, G, Mapping>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||||
|
ret_dtype: Type,
|
||||||
|
mapping: Mapping,
|
||||||
|
) -> Result<ScalarOrNDArray<'ctx>, String>
|
||||||
|
where
|
||||||
|
G: CodeGenerator + ?Sized,
|
||||||
|
Mapping: FnOnce(
|
||||||
|
&mut G,
|
||||||
|
&mut CodeGenContext<'ctx, 'a>,
|
||||||
|
BasicValueEnum<'ctx>,
|
||||||
|
) -> Result<BasicValueEnum<'ctx>, String>,
|
||||||
|
{
|
||||||
|
ScalarOrNDArray::broadcasting_starmap(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
&[*self],
|
||||||
|
ret_dtype,
|
||||||
|
|generator, ctx, scalars| mapping(generator, ctx, scalars[0]),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
|
@ -2,6 +2,7 @@ pub mod array;
|
||||||
pub mod broadcast;
|
pub mod broadcast;
|
||||||
pub mod factory;
|
pub mod factory;
|
||||||
pub mod indexing;
|
pub mod indexing;
|
||||||
|
pub mod map;
|
||||||
pub mod nditer;
|
pub mod nditer;
|
||||||
pub mod shape_util;
|
pub mod shape_util;
|
||||||
pub mod view;
|
pub mod view;
|
||||||
|
@ -20,6 +21,7 @@ use crate::{
|
||||||
call_nac3_ndarray_get_pelement_by_indices, call_nac3_ndarray_is_c_contiguous,
|
call_nac3_ndarray_get_pelement_by_indices, call_nac3_ndarray_is_c_contiguous,
|
||||||
call_nac3_ndarray_len, call_nac3_ndarray_nbytes,
|
call_nac3_ndarray_len, call_nac3_ndarray_nbytes,
|
||||||
call_nac3_ndarray_set_strides_by_shape, call_nac3_ndarray_size,
|
call_nac3_ndarray_set_strides_by_shape, call_nac3_ndarray_size,
|
||||||
|
call_nac3_ndarray_util_assert_output_shape_same,
|
||||||
},
|
},
|
||||||
model::*,
|
model::*,
|
||||||
CodeGenContext, CodeGenerator,
|
CodeGenContext, CodeGenerator,
|
||||||
|
@ -506,6 +508,31 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||||
ndarray.instance.set(ctx, |f| f.data, data);
|
ndarray.instance.set(ctx, |f| f.data, data);
|
||||||
ndarray
|
ndarray
|
||||||
}
|
}
|
||||||
|
/// Check if this `NDArray` can be used as an `out` ndarray for an operation.
|
||||||
|
///
|
||||||
|
/// Raise an exception if the shapes do not match.
|
||||||
|
pub fn assert_can_be_written_by_out<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
out_ndims: u64,
|
||||||
|
out_shape: Instance<'ctx, Ptr<Int<SizeT>>>,
|
||||||
|
) {
|
||||||
|
let ndarray_ndims = self.ndims_llvm(generator, ctx.ctx);
|
||||||
|
let ndarray_shape = self.instance.get(generator, ctx, |f| f.shape);
|
||||||
|
|
||||||
|
let output_ndims = Int(SizeT).const_int(generator, ctx.ctx, out_ndims);
|
||||||
|
let output_shape = out_shape;
|
||||||
|
|
||||||
|
call_nac3_ndarray_util_assert_output_shape_same(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
ndarray_ndims,
|
||||||
|
ndarray_shape,
|
||||||
|
output_ndims,
|
||||||
|
output_shape,
|
||||||
|
);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// A convenience enum for implementing functions that acts on scalars or ndarrays or both.
|
/// A convenience enum for implementing functions that acts on scalars or ndarrays or both.
|
||||||
|
@ -588,3 +615,27 @@ impl<'ctx> ScalarOrNDArray<'ctx> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// An helper enum specifying how a function should produce its output.
|
||||||
|
///
|
||||||
|
/// Many functions in NumPy has an optional `out` parameter (e.g., `matmul`). If `out` is specified
|
||||||
|
/// with an ndarray, the result of a function will be written to `out`. If `out` is not specified, a function will
|
||||||
|
/// create a new ndarray and store the result in it.
|
||||||
|
#[derive(Debug, Clone, Copy)]
|
||||||
|
pub enum NDArrayOut<'ctx> {
|
||||||
|
/// Tell a function should create a new ndarray with the expected element type `dtype`.
|
||||||
|
NewNDArray { dtype: Type },
|
||||||
|
/// Tell a function to write the result to `ndarray`.
|
||||||
|
WriteToNDArray { ndarray: NDArrayObject<'ctx> },
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx> NDArrayOut<'ctx> {
|
||||||
|
/// Get the dtype of this output.
|
||||||
|
#[must_use]
|
||||||
|
pub fn get_dtype(&self) -> Type {
|
||||||
|
match self {
|
||||||
|
NDArrayOut::NewNDArray { dtype } => *dtype,
|
||||||
|
NDArrayOut::WriteToNDArray { ndarray } => ndarray.dtype,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue