From cc32ed9a24cc2e628f8a633c41b1a5adce53aaf8 Mon Sep 17 00:00:00 2001 From: lyken Date: Tue, 20 Aug 2024 17:15:44 +0800 Subject: [PATCH] core/ndstrides: add ScalarOrNDArray::get_dtype --- nac3core/src/codegen/object/ndarray/mod.rs | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/nac3core/src/codegen/object/ndarray/mod.rs b/nac3core/src/codegen/object/ndarray/mod.rs index d8b0693d..8582c8c2 100644 --- a/nac3core/src/codegen/object/ndarray/mod.rs +++ b/nac3core/src/codegen/object/ndarray/mod.rs @@ -557,6 +557,15 @@ impl<'ctx> ScalarOrNDArray<'ctx> { ScalarOrNDArray::Scalar(scalar) => NDArrayObject::make_unsized(generator, ctx, *scalar), } } + + /// Get the dtype of the ndarray created if this were called with [`ScalarOrNDArray::to_ndarray`]. + #[must_use] + pub fn get_dtype(&self) -> Type { + match self { + ScalarOrNDArray::NDArray(ndarray) => ndarray.dtype, + ScalarOrNDArray::Scalar(scalar) => scalar.ty, + } + } } /// An helper enum specifying how a function should produce its output.