Implement DeviceCopy for UnitComplex, UnitQuaternion, and Unit<Matrix> instead of using a blanket impl

This commit is contained in:
Sébastien Crozet 2021-11-26 18:12:39 +01:00
parent e8c6a7c0a2
commit 9297cc5754
3 changed files with 21 additions and 4 deletions

View File

@ -26,10 +26,10 @@ use crate::{Dim, Matrix, OMatrix, RealField, Scalar, SimdComplexField, SimdRealF
/// in their documentation, read their dedicated pages directly. /// in their documentation, read their dedicated pages directly.
#[repr(transparent)] #[repr(transparent)]
#[derive(Clone, Hash, Copy)] #[derive(Clone, Hash, Copy)]
#[cfg_attr( // #[cfg_attr(
all(not(target_os = "cuda"), feature = "cuda"), // all(not(target_os = "cuda"), feature = "cuda"),
derive(cust::DeviceCopy) // derive(cust::DeviceCopy)
)] // )]
pub struct Unit<T> { pub struct Unit<T> {
pub(crate) value: T, pub(crate) value: T,
} }
@ -122,6 +122,17 @@ mod rkyv_impl {
} }
} }
#[cfg(all(not(target_os = "cuda"), feature = "cuda"))]
unsafe impl<T: cust::memory::DeviceCopy, R, C, S> cust::memory::DeviceCopy
for Unit<Matrix<T, R, C, S>>
where
T: Scalar,
R: Dim,
C: Dim,
S: RawStorage<T, R, C> + Copy,
{
}
impl<T, R, C, S> PartialEq for Unit<Matrix<T, R, C, S>> impl<T, R, C, S> PartialEq for Unit<Matrix<T, R, C, S>>
where where
T: Scalar + PartialEq, T: Scalar + PartialEq,

View File

@ -1068,6 +1068,9 @@ impl<T: RealField + fmt::Display> fmt::Display for Quaternion<T> {
/// A unit quaternions. May be used to represent a rotation. /// A unit quaternions. May be used to represent a rotation.
pub type UnitQuaternion<T> = Unit<Quaternion<T>>; pub type UnitQuaternion<T> = Unit<Quaternion<T>>;
#[cfg(all(not(target_os = "cuda"), feature = "cuda"))]
unsafe impl<T: cust::memory::DeviceCopy> cust::memory::DeviceCopy for UnitQuaternion<T> {}
impl<T: Scalar + ClosedNeg + PartialEq> PartialEq for UnitQuaternion<T> { impl<T: Scalar + ClosedNeg + PartialEq> PartialEq for UnitQuaternion<T> {
#[inline] #[inline]
fn eq(&self, rhs: &Self) -> bool { fn eq(&self, rhs: &Self) -> bool {

View File

@ -31,6 +31,9 @@ use std::cmp::{Eq, PartialEq};
/// * [Conversion to a matrix <span style="float:right;">`to_rotation_matrix`, `to_homogeneous`…</span>](#conversion-to-a-matrix) /// * [Conversion to a matrix <span style="float:right;">`to_rotation_matrix`, `to_homogeneous`…</span>](#conversion-to-a-matrix)
pub type UnitComplex<T> = Unit<Complex<T>>; pub type UnitComplex<T> = Unit<Complex<T>>;
#[cfg(all(not(target_os = "cuda"), feature = "cuda"))]
unsafe impl<T: cust::memory::DeviceCopy> cust::memory::DeviceCopy for UnitComplex<T> {}
impl<T: Scalar + PartialEq> PartialEq for UnitComplex<T> { impl<T: Scalar + PartialEq> PartialEq for UnitComplex<T> {
#[inline] #[inline]
fn eq(&self, rhs: &Self) -> bool { fn eq(&self, rhs: &Self) -> bool {