core/ndstrides: add ScalarOrNDArray::get_dtype
This commit is contained in:
parent
f25fee1bbc
commit
cc32ed9a24
|
@ -557,6 +557,15 @@ impl<'ctx> ScalarOrNDArray<'ctx> {
|
||||||
ScalarOrNDArray::Scalar(scalar) => NDArrayObject::make_unsized(generator, ctx, *scalar),
|
ScalarOrNDArray::Scalar(scalar) => NDArrayObject::make_unsized(generator, ctx, *scalar),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Get the dtype of the ndarray created if this were called with [`ScalarOrNDArray::to_ndarray`].
|
||||||
|
#[must_use]
|
||||||
|
pub fn get_dtype(&self) -> Type {
|
||||||
|
match self {
|
||||||
|
ScalarOrNDArray::NDArray(ndarray) => ndarray.dtype,
|
||||||
|
ScalarOrNDArray::Scalar(scalar) => scalar.ty,
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// An helper enum specifying how a function should produce its output.
|
/// An helper enum specifying how a function should produce its output.
|
||||||
|
|
Loading…
Reference in New Issue