core/ndstrides: add ScalarOrNDArray::get_dtype
This commit is contained in:
parent
0c6ba9fd6b
commit
11824e9420
|
@ -557,6 +557,15 @@ impl<'ctx> ScalarOrNDArray<'ctx> {
|
|||
ScalarOrNDArray::Scalar(scalar) => NDArrayObject::make_unsized(generator, ctx, *scalar),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the dtype of the ndarray created if this were called with [`ScalarOrNDArray::to_ndarray`].
|
||||
#[must_use]
|
||||
pub fn get_dtype(&self) -> Type {
|
||||
match self {
|
||||
ScalarOrNDArray::NDArray(ndarray) => ndarray.dtype,
|
||||
ScalarOrNDArray::Scalar(scalar) => scalar.ty,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// An helper enum specifying how a function should produce its output.
|
||||
|
|
Loading…
Reference in New Issue