diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 84a4c3f..9e3faf2 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -8,7 +8,10 @@ use std::{ use inkwell::{ attributes::{Attribute, AttributeLoc}, types::{AnyType, BasicType, BasicTypeEnum}, - values::{BasicValueEnum, CallSiteValue, FunctionValue, IntValue, PointerValue, StructValue}, + values::{ + BasicValue, BasicValueEnum, CallSiteValue, FunctionValue, IntValue, PointerValue, + StructValue, + }, AddressSpace, IntPredicate, OptimizationLevel, }; use itertools::{chain, izip, Either, Itertools}; @@ -20,8 +23,8 @@ use nac3parser::ast::{ use super::{ classes::{ - ArrayLikeIndexer, ArrayLikeValue, ListType, ListValue, NDArrayValue, ProxyType, ProxyValue, - RangeValue, UntypedArrayLikeAccessor, + ArrayLikeIndexer, ArrayLikeValue, ListType, ListValue, ProxyType, ProxyValue, RangeValue, + UntypedArrayLikeAccessor, }, concrete_type::{ConcreteFuncArg, ConcreteTypeEnum, ConcreteTypeStore}, gen_in_range_check, get_llvm_abi_type, get_llvm_type, get_va_count_arg_name, @@ -31,10 +34,13 @@ use super::{ call_int_umin, call_memcpy_generic, }, macros::codegen_unreachable, - need_sret, numpy, + need_sret, object::{ any::AnyObject, - ndarray::{indexing::util::gen_ndarray_subscript_ndindices, NDArrayObject}, + ndarray::{ + indexing::util::gen_ndarray_subscript_ndindices, NDArrayObject, NDArrayOut, + ScalarOrNDArray, + }, }, stmt::{ gen_for_callback_incrementing, gen_if_callback, gen_if_else_expr_callback, gen_raise, @@ -44,7 +50,7 @@ use super::{ }; use crate::{ symbol_resolver::{SymbolValue, ValueEnum}, - toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, DefinitionId, TopLevelDef}, + toplevel::{helper::PrimDef, DefinitionId, TopLevelDef}, typecheck::{ magic_methods::{Binop, BinopVariant, HasOpInfo}, typedef::{FunSignature, FuncArg, Type, TypeEnum, TypeVarId, Unifier, VarMap}, @@ -1549,99 +1555,71 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( } else if ty1.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) || ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) { - let llvm_usize = generator.get_size_type(ctx.ctx); + let left = + ScalarOrNDArray::split_object(generator, ctx, AnyObject { ty: ty1, value: left_val }); + let right = + ScalarOrNDArray::split_object(generator, ctx, AnyObject { ty: ty2, value: right_val }); - let is_ndarray1 = ty1.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - let is_ndarray2 = ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); + // Inhomogeneous binary operations are not supported. + assert!(ctx.unifier.unioned(left.get_dtype(), right.get_dtype())); - if is_ndarray1 && is_ndarray2 { - let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty1); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty2); + let common_dtype = left.get_dtype(); - assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); + let out = match op.variant { + BinopVariant::Normal => NDArrayOut::NewNDArray { dtype: common_dtype }, + BinopVariant::AugAssign => { + // If this is an augmented assignment. + // `left` has to be an ndarray. If it were a scalar then NAC3 simply doesn't support it. + if let ScalarOrNDArray::NDArray(out_ndarray) = left { + NDArrayOut::WriteToNDArray { ndarray: out_ndarray } + } else { + panic!("left must be an ndarray") + } + } + }; - let left_val = - NDArrayValue::from_ptr_val(left_val.into_pointer_value(), llvm_usize, None); - let right_val = - NDArrayValue::from_ptr_val(right_val.into_pointer_value(), llvm_usize, None); - - let res = if op.base == Operator::MatMult { - // MatMult is the only binop which is not an elementwise op - numpy::ndarray_matmul_2d( - generator, - ctx, - ndarray_dtype1, - match op.variant { - BinopVariant::Normal => None, - BinopVariant::AugAssign => Some(left_val), - }, - left_val, - right_val, - )? - } else { - numpy::ndarray_elementwise_binop_impl( - generator, - ctx, - ndarray_dtype1, - match op.variant { - BinopVariant::Normal => None, - BinopVariant::AugAssign => Some(left_val), - }, - (left_val.as_base_value().into(), false), - (right_val.as_base_value().into(), false), - |generator, ctx, (lhs, rhs)| { - gen_binop_expr_with_values( - generator, - ctx, - (&Some(ndarray_dtype1), lhs), - op, - (&Some(ndarray_dtype2), rhs), - ctx.current_loc, - )? - .unwrap() - .to_basic_value_enum( - ctx, - generator, - ndarray_dtype1, - ) - }, - )? - }; - - Ok(Some(res.as_base_value().into())) + if op.base == Operator::MatMult { + // Handle matrix multiplication. + todo!() } else { - let (ndarray_dtype, _) = - unpack_ndarray_var_tys(&mut ctx.unifier, if is_ndarray1 { ty1 } else { ty2 }); - let ndarray_val = NDArrayValue::from_ptr_val( - if is_ndarray1 { left_val } else { right_val }.into_pointer_value(), - llvm_usize, - None, - ); - let res = numpy::ndarray_elementwise_binop_impl( + // For other operations, they are all elementwise operations. + + // There are only three cases: + // - LHS is a scalar, RHS is an ndarray. + // - LHS is an ndarray, RHS is a scalar. + // - LHS is an ndarray, RHS is an ndarray. + // + // For all cases, the scalar operand is promoted to an ndarray, + // the two are then broadcasted, and starmapped through. + + let left = left.to_ndarray(generator, ctx); + let right = right.to_ndarray(generator, ctx); + + let result = NDArrayObject::broadcast_starmap( generator, ctx, - ndarray_dtype, - match op.variant { - BinopVariant::Normal => None, - BinopVariant::AugAssign => Some(ndarray_val), - }, - (left_val, !is_ndarray1), - (right_val, !is_ndarray2), - |generator, ctx, (lhs, rhs)| { - gen_binop_expr_with_values( + &[left, right], + out, + |generator, ctx, scalars| { + let left_value = scalars[0]; + let right_value = scalars[1]; + + let result = gen_binop_expr_with_values( generator, ctx, - (&Some(ndarray_dtype), lhs), + (&Some(left.dtype), left_value), op, - (&Some(ndarray_dtype), rhs), + (&Some(right.dtype), right_value), ctx.current_loc, )? .unwrap() - .to_basic_value_enum(ctx, generator, ndarray_dtype) - }, - )?; + .to_basic_value_enum(ctx, generator, common_dtype)?; - Ok(Some(res.as_base_value().into())) + Ok(result) + }, + ) + .unwrap(); + Ok(Some(ValueEnum::Dynamic(result.instance.value.as_basic_value_enum()))) } } else { let left_ty_enum = ctx.unifier.get_ty_immutable(left_ty.unwrap()); @@ -1799,14 +1777,12 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>( _ => val.into(), } } else if ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) { - let llvm_usize = generator.get_size_type(ctx.ctx); - let (ndarray_dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty); - - let val = NDArrayValue::from_ptr_val(val.into_pointer_value(), llvm_usize, None); + let ndarray = AnyObject { value: val, ty }; + let ndarray = NDArrayObject::from_object(generator, ctx, ndarray); // ndarray uses `~` rather than `not` to perform elementwise inversion, convert it before // passing it to the elementwise codegen function - let op = if ndarray_dtype.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::Bool.id()) { + let op = if ndarray.dtype.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::Bool.id()) { if op == ast::Unaryop::Invert { ast::Unaryop::Not } else { @@ -1820,20 +1796,18 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>( op }; - let res = numpy::ndarray_elementwise_unaryop_impl( + let mapped_ndarray = ndarray.map( generator, ctx, - ndarray_dtype, - None, - val, - |generator, ctx, val| { - gen_unaryop_expr_with_values(generator, ctx, op, (&Some(ndarray_dtype), val))? + NDArrayOut::NewNDArray { dtype: ndarray.dtype }, + |generator, ctx, scalar| { + gen_unaryop_expr_with_values(generator, ctx, op, (&Some(ndarray.dtype), scalar))? .unwrap() - .to_basic_value_enum(ctx, generator, ndarray_dtype) + .to_basic_value_enum(ctx, generator, ndarray.dtype) }, )?; - res.as_base_value().into() + ValueEnum::Dynamic(mapped_ndarray.instance.value.as_basic_value_enum()) } else { unimplemented!() })) @@ -1876,85 +1850,46 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( if left_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) || right_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) { - let llvm_usize = generator.get_size_type(ctx.ctx); - - let (Some(left_ty), lhs) = left else { codegen_unreachable!(ctx) }; - let (Some(right_ty), rhs) = comparators[0] else { codegen_unreachable!(ctx) }; + let (Some(left_ty), left) = left else { codegen_unreachable!(ctx) }; + let (Some(right_ty), right) = comparators[0] else { codegen_unreachable!(ctx) }; let op = ops[0]; - let is_ndarray1 = - left_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - let is_ndarray2 = - right_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); + let left = AnyObject { value: left, ty: left_ty }; + let left = + ScalarOrNDArray::split_object(generator, ctx, left).to_ndarray(generator, ctx); - return if is_ndarray1 && is_ndarray2 { - let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, left_ty); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, right_ty); + let right = AnyObject { value: right, ty: right_ty }; + let right = + ScalarOrNDArray::split_object(generator, ctx, right).to_ndarray(generator, ctx); - assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); + let result_ndarray = NDArrayObject::broadcast_starmap( + generator, + ctx, + &[left, right], + NDArrayOut::NewNDArray { dtype: ctx.primitives.bool }, + |generator, ctx, scalars| { + let left_scalar = scalars[0]; + let right_scalar = scalars[1]; - let left_val = - NDArrayValue::from_ptr_val(lhs.into_pointer_value(), llvm_usize, None); - let res = numpy::ndarray_elementwise_binop_impl( - generator, - ctx, - ctx.primitives.bool, - None, - (left_val.as_base_value().into(), false), - (rhs, false), - |generator, ctx, (lhs, rhs)| { - let val = gen_cmpop_expr_with_values( - generator, - ctx, - (Some(ndarray_dtype1), lhs), - &[op], - &[(Some(ndarray_dtype2), rhs)], - )? - .unwrap() - .to_basic_value_enum( - ctx, - generator, - ctx.primitives.bool, - )?; + let val = gen_cmpop_expr_with_values( + generator, + ctx, + (Some(left.dtype), left_scalar), + &[op], + &[(Some(right.dtype), right_scalar)], + )? + .unwrap() + .to_basic_value_enum( + ctx, + generator, + ctx.primitives.bool, + )?; - Ok(generator.bool_to_i8(ctx, val.into_int_value()).into()) - }, - )?; + Ok(generator.bool_to_i8(ctx, val.into_int_value()).into()) + }, + )?; - Ok(Some(res.as_base_value().into())) - } else { - let (ndarray_dtype, _) = unpack_ndarray_var_tys( - &mut ctx.unifier, - if is_ndarray1 { left_ty } else { right_ty }, - ); - let res = numpy::ndarray_elementwise_binop_impl( - generator, - ctx, - ctx.primitives.bool, - None, - (lhs, !is_ndarray1), - (rhs, !is_ndarray2), - |generator, ctx, (lhs, rhs)| { - let val = gen_cmpop_expr_with_values( - generator, - ctx, - (Some(ndarray_dtype), lhs), - &[op], - &[(Some(ndarray_dtype), rhs)], - )? - .unwrap() - .to_basic_value_enum( - ctx, - generator, - ctx.primitives.bool, - )?; - - Ok(generator.bool_to_i8(ctx, val.into_int_value()).into()) - }, - )?; - - Ok(Some(res.as_base_value().into())) - }; + return Ok(Some(result_ndarray.instance.value.into())); } } diff --git a/nac3core/src/codegen/object/ndarray/map.rs b/nac3core/src/codegen/object/ndarray/map.rs new file mode 100644 index 0000000..6f034e5 --- /dev/null +++ b/nac3core/src/codegen/object/ndarray/map.rs @@ -0,0 +1,219 @@ +use inkwell::values::BasicValueEnum; +use itertools::Itertools; + +use crate::{ + codegen::{ + object::ndarray::{AnyObject, NDArrayObject}, + stmt::gen_for_callback, + CodeGenContext, CodeGenerator, + }, + typecheck::typedef::Type, +}; + +use super::{nditer::NDIterHandle, NDArrayOut, ScalarOrNDArray}; + +impl<'ctx> NDArrayObject<'ctx> { + /// Generate LLVM IR to broadcast `ndarray`s together, and starmap through them with `mapping` elementwise. + /// + /// `mapping` is an LLVM IR generator. The input of `mapping` is the list of elements when iterating through + /// the input `ndarrays` after broadcasting. The output of `mapping` is the result of the elementwise operation. + /// + /// `out` specifies whether the result should be a new ndarray or to be written an existing ndarray. + pub fn broadcast_starmap<'a, G, MappingFn>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, 'a>, + ndarrays: &[Self], + out: NDArrayOut<'ctx>, + mapping: MappingFn, + ) -> Result + where + G: CodeGenerator + ?Sized, + MappingFn: FnOnce( + &mut G, + &mut CodeGenContext<'ctx, 'a>, + &[BasicValueEnum<'ctx>], + ) -> Result, String>, + { + // Broadcast inputs + let broadcast_result = NDArrayObject::broadcast(generator, ctx, ndarrays); + + let out_ndarray = match out { + NDArrayOut::NewNDArray { dtype } => { + // Create a new ndarray based on the broadcast shape. + let result_ndarray = + NDArrayObject::alloca(generator, ctx, dtype, broadcast_result.ndims); + result_ndarray.copy_shape_from_array(generator, ctx, broadcast_result.shape); + result_ndarray.create_data(generator, ctx); + result_ndarray + } + NDArrayOut::WriteToNDArray { ndarray: result_ndarray } => { + // Use an existing ndarray. + + // Check that its shape is compatible with the broadcast shape. + result_ndarray.assert_can_be_written_by_out( + generator, + ctx, + broadcast_result.ndims, + broadcast_result.shape, + ); + result_ndarray + } + }; + + // Map element-wise and store results into `mapped_ndarray`. + let nditer = NDIterHandle::new(generator, ctx, out_ndarray); + gen_for_callback( + generator, + ctx, + Some("broadcast_starmap"), + |generator, ctx| { + // Create NDIters for all broadcasted input ndarrays. + let other_nditers = broadcast_result + .ndarrays + .iter() + .map(|ndarray| NDIterHandle::new(generator, ctx, *ndarray)) + .collect_vec(); + Ok((nditer, other_nditers)) + }, + |generator, ctx, (out_nditer, _in_nditers)| { + // We can simply use `out_nditer`'s `has_element()`. + // `in_nditers`' `has_element()`s should return the same value. + Ok(out_nditer.has_element(generator, ctx).value) + }, + |generator, ctx, _hooks, (out_nditer, in_nditers)| { + // Get all the scalars from the broadcasted input ndarrays, pass them to `mapping`, + // and write to `out_ndarray`. + let in_scalars = in_nditers + .iter() + .map(|nditer| nditer.get_scalar(generator, ctx).value) + .collect_vec(); + + let result = mapping(generator, ctx, &in_scalars)?; + + let p = out_nditer.get_pointer(generator, ctx); + ctx.builder.build_store(p, result).unwrap(); + + Ok(()) + }, + |generator, ctx, (out_nditer, in_nditers)| { + // Advance all iterators + out_nditer.next(generator, ctx); + in_nditers.iter().for_each(|nditer| nditer.next(generator, ctx)); + Ok(()) + }, + )?; + + Ok(out_ndarray) + } + + /// Map through this ndarray with an elementwise function. + pub fn map<'a, G, Mapping>( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, 'a>, + out: NDArrayOut<'ctx>, + mapping: Mapping, + ) -> Result + where + G: CodeGenerator + ?Sized, + Mapping: FnOnce( + &mut G, + &mut CodeGenContext<'ctx, 'a>, + BasicValueEnum<'ctx>, + ) -> Result, String>, + { + NDArrayObject::broadcast_starmap( + generator, + ctx, + &[*self], + out, + |generator, ctx, scalars| mapping(generator, ctx, scalars[0]), + ) + } +} + +impl<'ctx> ScalarOrNDArray<'ctx> { + /// Starmap through a list of inputs using `mapping`, where an input could be an ndarray, a scalar. + /// + /// This function is very helpful when implementing NumPy functions that takes on either scalars or ndarrays or a mix of them + /// as their inputs and produces either an ndarray with broadcast, or a scalar if all its inputs are all scalars. + /// + /// For example ,this function can be used to implement `np.add`, which has the following behaviors: + /// - `np.add(3, 4) = 7` # (scalar, scalar) -> scalar + /// - `np.add(3, np.array([4, 5, 6]))` # (scalar, ndarray) -> ndarray; the first `scalar` is converted into an ndarray and broadcasted. + /// - `np.add(np.array([[1], [2], [3]]), np.array([[4, 5, 6]]))` # (ndarray, ndarray) -> ndarray; there is broadcasting. + /// + /// ## Details: + /// + /// If `inputs` are all [`ScalarOrNDArray::Scalar`], the output will be a [`ScalarOrNDArray::Scalar`] with type `ret_dtype`. + /// + /// Otherwise (if there are any [`ScalarOrNDArray::NDArray`] in `inputs`), all inputs will be 'as-ndarray'-ed into ndarrays, + /// then all inputs (now all ndarrays) will be passed to [`NDArrayObject::broadcasting_starmap`] and **create** a new ndarray + /// with dtype `ret_dtype`. + pub fn broadcasting_starmap<'a, G, MappingFn>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, 'a>, + inputs: &[ScalarOrNDArray<'ctx>], + ret_dtype: Type, + mapping: MappingFn, + ) -> Result, String> + where + G: CodeGenerator + ?Sized, + MappingFn: FnOnce( + &mut G, + &mut CodeGenContext<'ctx, 'a>, + &[BasicValueEnum<'ctx>], + ) -> Result, String>, + { + // Check if all inputs are Scalars + let all_scalars: Option> = inputs.iter().map(AnyObject::try_from).try_collect().ok(); + + if let Some(scalars) = all_scalars { + let scalars = scalars.iter().map(|scalar| scalar.value).collect_vec(); + let value = mapping(generator, ctx, &scalars)?; + + Ok(ScalarOrNDArray::Scalar(AnyObject { ty: ret_dtype, value })) + } else { + // Promote all input to ndarrays and map through them. + let inputs = inputs.iter().map(|input| input.to_ndarray(generator, ctx)).collect_vec(); + let ndarray = NDArrayObject::broadcast_starmap( + generator, + ctx, + &inputs, + NDArrayOut::NewNDArray { dtype: ret_dtype }, + mapping, + )?; + Ok(ScalarOrNDArray::NDArray(ndarray)) + } + } + + /// Map through this [`ScalarOrNDArray`] with an elementwise function. + /// + /// If this is a scalar, `mapping` will directly act on the scalar. This function will return a [`ScalarOrNDArray::Scalar`] of that result. + /// + /// If this is an ndarray, `mapping` will be applied to the elements of the ndarray. A new ndarray of the results will be created and + /// returned as a [`ScalarOrNDArray::NDArray`]. + pub fn map<'a, G, Mapping>( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, 'a>, + ret_dtype: Type, + mapping: Mapping, + ) -> Result, String> + where + G: CodeGenerator + ?Sized, + Mapping: FnOnce( + &mut G, + &mut CodeGenContext<'ctx, 'a>, + BasicValueEnum<'ctx>, + ) -> Result, String>, + { + ScalarOrNDArray::broadcasting_starmap( + generator, + ctx, + &[*self], + ret_dtype, + |generator, ctx, scalars| mapping(generator, ctx, scalars[0]), + ) + } +} diff --git a/nac3core/src/codegen/object/ndarray/mod.rs b/nac3core/src/codegen/object/ndarray/mod.rs index 655f8e7..276ee3a 100644 --- a/nac3core/src/codegen/object/ndarray/mod.rs +++ b/nac3core/src/codegen/object/ndarray/mod.rs @@ -13,6 +13,7 @@ use crate::{ call_nac3_ndarray_get_pelement_by_indices, call_nac3_ndarray_is_c_contiguous, call_nac3_ndarray_len, call_nac3_ndarray_nbytes, call_nac3_ndarray_set_strides_by_shape, call_nac3_ndarray_size, + call_nac3_ndarray_util_assert_output_shape_same, }, model::*, CodeGenContext, CodeGenerator, @@ -25,6 +26,7 @@ pub mod array; pub mod broadcast; pub mod factory; pub mod indexing; +pub mod map; pub mod nditer; pub mod shape_util; pub mod view; @@ -499,6 +501,31 @@ impl<'ctx> NDArrayObject<'ctx> { ndarray.instance.set(ctx, |f| f.data, data); ndarray } + /// Check if this `NDArray` can be used as an `out` ndarray for an operation. + /// + /// Raise an exception if the shapes do not match. + pub fn assert_can_be_written_by_out( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + out_ndims: u64, + out_shape: Instance<'ctx, Ptr>>, + ) { + let ndarray_ndims = self.ndims_llvm(generator, ctx.ctx); + let ndarray_shape = self.instance.get(generator, ctx, |f| f.shape); + + let output_ndims = Int(SizeT).const_int(generator, ctx.ctx, out_ndims, false); + let output_shape = out_shape; + + call_nac3_ndarray_util_assert_output_shape_same( + generator, + ctx, + ndarray_ndims, + ndarray_shape, + output_ndims, + output_shape, + ); + } } /// A convenience enum for implementing functions that acts on scalars or ndarrays or both. @@ -584,3 +611,27 @@ impl<'ctx> ScalarOrNDArray<'ctx> { } } } + +/// An helper enum specifying how a function should produce its output. +/// +/// Many functions in NumPy has an optional `out` parameter (e.g., `matmul`). If `out` is specified +/// with an ndarray, the result of a function will be written to `out`. If `out` is not specified, a function will +/// create a new ndarray and store the result in it. +#[derive(Debug, Clone, Copy)] +pub enum NDArrayOut<'ctx> { + /// Tell a function should create a new ndarray with the expected element type `dtype`. + NewNDArray { dtype: Type }, + /// Tell a function to write the result to `ndarray`. + WriteToNDArray { ndarray: NDArrayObject<'ctx> }, +} + +impl<'ctx> NDArrayOut<'ctx> { + /// Get the dtype of this output. + #[must_use] + pub fn get_dtype(&self) -> Type { + match self { + NDArrayOut::NewNDArray { dtype } => *dtype, + NDArrayOut::WriteToNDArray { ndarray } => ndarray.dtype, + } + } +}