core/ndstrides: add ScalarOrNDArray::get_dtype

This commit is contained in:
lyken 2024-08-20 17:15:44 +08:00
parent 0c6ba9fd6b
commit 11824e9420
No known key found for this signature in database
GPG Key ID: 3BD5FC6AC8325DD8
1 changed files with 9 additions and 0 deletions

View File

@ -557,6 +557,15 @@ impl<'ctx> ScalarOrNDArray<'ctx> {
ScalarOrNDArray::Scalar(scalar) => NDArrayObject::make_unsized(generator, ctx, *scalar), ScalarOrNDArray::Scalar(scalar) => NDArrayObject::make_unsized(generator, ctx, *scalar),
} }
} }
/// Get the dtype of the ndarray created if this were called with [`ScalarOrNDArray::to_ndarray`].
#[must_use]
pub fn get_dtype(&self) -> Type {
match self {
ScalarOrNDArray::NDArray(ndarray) => ndarray.dtype,
ScalarOrNDArray::Scalar(scalar) => scalar.ty,
}
}
} }
/// An helper enum specifying how a function should produce its output. /// An helper enum specifying how a function should produce its output.