Fix the special-case for 3x3 Real SVD
This commit is contained in:
parent
a9890e2a2c
commit
e0a1b1bc34
|
@ -80,6 +80,16 @@ where
|
||||||
+ Allocator<T::RealField, DimMinimum<R, C>>
|
+ Allocator<T::RealField, DimMinimum<R, C>>
|
||||||
+ Allocator<T::RealField, DimDiff<DimMinimum<R, C>, U1>>,
|
+ Allocator<T::RealField, DimDiff<DimMinimum<R, C>, U1>>,
|
||||||
{
|
{
|
||||||
|
fn use_special_always_ordered_svd2() -> bool {
|
||||||
|
TypeId::of::<OMatrix<T, R, C>>() == TypeId::of::<Matrix2<T::RealField>>()
|
||||||
|
&& TypeId::of::<Self>() == TypeId::of::<SVD<T::RealField, U2, U2>>()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn use_special_always_ordered_svd3() -> bool {
|
||||||
|
TypeId::of::<OMatrix<T, R, C>>() == TypeId::of::<Matrix3<T::RealField>>()
|
||||||
|
&& TypeId::of::<Self>() == TypeId::of::<SVD<T::RealField, U3, U3>>()
|
||||||
|
}
|
||||||
|
|
||||||
/// Computes the Singular Value Decomposition of `matrix` using implicit shift.
|
/// Computes the Singular Value Decomposition of `matrix` using implicit shift.
|
||||||
/// The singular values are not guaranteed to be sorted in any particular order.
|
/// The singular values are not guaranteed to be sorted in any particular order.
|
||||||
/// If a descending order is required, consider using `new` instead.
|
/// If a descending order is required, consider using `new` instead.
|
||||||
|
@ -120,20 +130,16 @@ where
|
||||||
let (nrows, ncols) = matrix.shape_generic();
|
let (nrows, ncols) = matrix.shape_generic();
|
||||||
let min_nrows_ncols = nrows.min(ncols);
|
let min_nrows_ncols = nrows.min(ncols);
|
||||||
|
|
||||||
if TypeId::of::<OMatrix<T, R, C>>() == TypeId::of::<Matrix2<T::RealField>>()
|
if Self::use_special_always_ordered_svd2() {
|
||||||
&& TypeId::of::<Self>() == TypeId::of::<SVD<T::RealField, U2, U2>>()
|
|
||||||
{
|
|
||||||
// SAFETY: the reference transmutes are OK since we checked that the types match exactly.
|
// SAFETY: the reference transmutes are OK since we checked that the types match exactly.
|
||||||
let matrix: &Matrix2<T::RealField> = unsafe { std::mem::transmute(&matrix) };
|
let matrix: &Matrix2<T::RealField> = unsafe { std::mem::transmute(&matrix) };
|
||||||
let result = super::svd2::svd2(matrix, compute_u, compute_v);
|
let result = super::svd2::svd_ordered2(matrix, compute_u, compute_v);
|
||||||
let typed_result: &Self = unsafe { std::mem::transmute(&result) };
|
let typed_result: &Self = unsafe { std::mem::transmute(&result) };
|
||||||
return Some(typed_result.clone());
|
return Some(typed_result.clone());
|
||||||
} else if TypeId::of::<OMatrix<T, R, C>>() == TypeId::of::<Matrix3<T::RealField>>()
|
} else if Self::use_special_always_ordered_svd3() {
|
||||||
&& TypeId::of::<Self>() == TypeId::of::<SVD<T::RealField, U3, U3>>()
|
|
||||||
{
|
|
||||||
// SAFETY: the reference transmutes are OK since we checked that the types match exactly.
|
// SAFETY: the reference transmutes are OK since we checked that the types match exactly.
|
||||||
let matrix: &Matrix3<T::RealField> = unsafe { std::mem::transmute(&matrix) };
|
let matrix: &Matrix3<T::RealField> = unsafe { std::mem::transmute(&matrix) };
|
||||||
let result = super::svd3::svd3(matrix, compute_u, compute_v, eps, max_niter);
|
let result = super::svd3::svd_ordered3(matrix, compute_u, compute_v, eps, max_niter);
|
||||||
let typed_result: &Self = unsafe { std::mem::transmute(&result) };
|
let typed_result: &Self = unsafe { std::mem::transmute(&result) };
|
||||||
return Some(typed_result.clone());
|
return Some(typed_result.clone());
|
||||||
}
|
}
|
||||||
|
@ -657,7 +663,11 @@ where
|
||||||
/// If this order is not required consider using `new_unordered`.
|
/// If this order is not required consider using `new_unordered`.
|
||||||
pub fn new(matrix: OMatrix<T, R, C>, compute_u: bool, compute_v: bool) -> Self {
|
pub fn new(matrix: OMatrix<T, R, C>, compute_u: bool, compute_v: bool) -> Self {
|
||||||
let mut svd = Self::new_unordered(matrix, compute_u, compute_v);
|
let mut svd = Self::new_unordered(matrix, compute_u, compute_v);
|
||||||
svd.sort_by_singular_values();
|
|
||||||
|
if !Self::use_special_always_ordered_svd3() && !Self::use_special_always_ordered_svd2() {
|
||||||
|
svd.sort_by_singular_values();
|
||||||
|
}
|
||||||
|
|
||||||
svd
|
svd
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -681,7 +691,11 @@ where
|
||||||
max_niter: usize,
|
max_niter: usize,
|
||||||
) -> Option<Self> {
|
) -> Option<Self> {
|
||||||
Self::try_new_unordered(matrix, compute_u, compute_v, eps, max_niter).map(|mut svd| {
|
Self::try_new_unordered(matrix, compute_u, compute_v, eps, max_niter).map(|mut svd| {
|
||||||
svd.sort_by_singular_values();
|
if !Self::use_special_always_ordered_svd3() && !Self::use_special_always_ordered_svd2()
|
||||||
|
{
|
||||||
|
svd.sort_by_singular_values();
|
||||||
|
}
|
||||||
|
|
||||||
svd
|
svd
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,7 +2,11 @@ use crate::{Matrix2, RealField, Vector2, SVD, U2};
|
||||||
|
|
||||||
// Implementation of the 2D SVD from https://ieeexplore.ieee.org/document/486688
|
// Implementation of the 2D SVD from https://ieeexplore.ieee.org/document/486688
|
||||||
// See also https://scicomp.stackexchange.com/questions/8899/robust-algorithm-for-2-times-2-svd
|
// See also https://scicomp.stackexchange.com/questions/8899/robust-algorithm-for-2-times-2-svd
|
||||||
pub fn svd2<T: RealField>(m: &Matrix2<T>, compute_u: bool, compute_v: bool) -> SVD<T, U2, U2> {
|
pub fn svd_ordered2<T: RealField>(
|
||||||
|
m: &Matrix2<T>,
|
||||||
|
compute_u: bool,
|
||||||
|
compute_v: bool,
|
||||||
|
) -> SVD<T, U2, U2> {
|
||||||
let half: T = crate::convert(0.5);
|
let half: T = crate::convert(0.5);
|
||||||
let one: T = crate::convert(1.0);
|
let one: T = crate::convert(1.0);
|
||||||
|
|
||||||
|
@ -12,6 +16,9 @@ pub fn svd2<T: RealField>(m: &Matrix2<T>, compute_u: bool, compute_v: bool) -> S
|
||||||
let h = (m.m21.clone() - m.m12.clone()) * half.clone();
|
let h = (m.m21.clone() - m.m12.clone()) * half.clone();
|
||||||
let q = (e.clone() * e.clone() + h.clone() * h.clone()).sqrt();
|
let q = (e.clone() * e.clone() + h.clone() * h.clone()).sqrt();
|
||||||
let r = (f.clone() * f.clone() + g.clone() * g.clone()).sqrt();
|
let r = (f.clone() * f.clone() + g.clone() * g.clone()).sqrt();
|
||||||
|
|
||||||
|
// Note that the singular values are always sorted because sx >= sy
|
||||||
|
// because q >= 0 and r >= 0.
|
||||||
let sx = q.clone() + r.clone();
|
let sx = q.clone() + r.clone();
|
||||||
let sy = q - r;
|
let sy = q - r;
|
||||||
let sy_sign = if sy < T::zero() { -one.clone() } else { one };
|
let sy_sign = if sy < T::zero() { -one.clone() } else { one };
|
||||||
|
|
|
@ -3,7 +3,10 @@ use simba::scalar::RealField;
|
||||||
|
|
||||||
// For the 3x3 case, on the GPU, it is much more efficient to compute the SVD
|
// For the 3x3 case, on the GPU, it is much more efficient to compute the SVD
|
||||||
// using an eigendecomposition followed by a QR decomposition.
|
// using an eigendecomposition followed by a QR decomposition.
|
||||||
pub fn svd3<T: RealField>(
|
//
|
||||||
|
// This is based on the paper "Computing the Singular Value Decomposition of 3 x 3 matrices with
|
||||||
|
// minimal branching and elementary floating point operations" from McAdams, et al.
|
||||||
|
pub fn svd_ordered3<T: RealField>(
|
||||||
m: &Matrix3<T>,
|
m: &Matrix3<T>,
|
||||||
compute_u: bool,
|
compute_u: bool,
|
||||||
compute_v: bool,
|
compute_v: bool,
|
||||||
|
@ -11,15 +14,42 @@ pub fn svd3<T: RealField>(
|
||||||
niter: usize,
|
niter: usize,
|
||||||
) -> Option<SVD<T, U3, U3>> {
|
) -> Option<SVD<T, U3, U3>> {
|
||||||
let s = m.tr_mul(&m);
|
let s = m.tr_mul(&m);
|
||||||
let v = s.try_symmetric_eigen(eps, niter)?.eigenvectors;
|
let mut v = s.try_symmetric_eigen(eps, niter)?.eigenvectors;
|
||||||
let b = m * &v;
|
let mut b = m * &v;
|
||||||
|
|
||||||
|
// Sort singular values. This is a necessary step to ensure that
|
||||||
|
// the QR decompositions R matrix ends up diagonal.
|
||||||
|
let mut rho0 = b.column(0).norm_squared();
|
||||||
|
let mut rho1 = b.column(1).norm_squared();
|
||||||
|
let mut rho2 = b.column(2).norm_squared();
|
||||||
|
|
||||||
|
if rho0 < rho1 {
|
||||||
|
b.swap_columns(0, 1);
|
||||||
|
b.column_mut(1).neg_mut();
|
||||||
|
v.swap_columns(0, 1);
|
||||||
|
v.column_mut(1).neg_mut();
|
||||||
|
std::mem::swap(&mut rho0, &mut rho1);
|
||||||
|
}
|
||||||
|
if rho0 < rho2 {
|
||||||
|
b.swap_columns(0, 2);
|
||||||
|
b.column_mut(2).neg_mut();
|
||||||
|
v.swap_columns(0, 2);
|
||||||
|
v.column_mut(2).neg_mut();
|
||||||
|
std::mem::swap(&mut rho0, &mut rho2);
|
||||||
|
}
|
||||||
|
if rho1 < rho2 {
|
||||||
|
b.swap_columns(1, 2);
|
||||||
|
b.column_mut(2).neg_mut();
|
||||||
|
v.swap_columns(1, 2);
|
||||||
|
v.column_mut(2).neg_mut();
|
||||||
|
std::mem::swap(&mut rho0, &mut rho2);
|
||||||
|
}
|
||||||
|
|
||||||
let qr = b.qr();
|
let qr = b.qr();
|
||||||
let singular_values = qr.diag_internal().map(|e| e.abs());
|
|
||||||
|
|
||||||
Some(SVD {
|
Some(SVD {
|
||||||
u: if compute_u { Some(qr.q()) } else { None },
|
u: if compute_u { Some(qr.q()) } else { None },
|
||||||
singular_values,
|
singular_values: qr.diag_internal().map(|e| e.abs()),
|
||||||
v_t: if compute_v { Some(v.transpose()) } else { None },
|
v_t: if compute_v { Some(v.transpose()) } else { None },
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -26,6 +26,7 @@ mod proptest_tests {
|
||||||
prop_assert!(s.iter().all(|e| *e >= 0.0));
|
prop_assert!(s.iter().all(|e| *e >= 0.0));
|
||||||
prop_assert!(relative_eq!(&u * ds * &v_t, recomp_m, epsilon = 1.0e-5));
|
prop_assert!(relative_eq!(&u * ds * &v_t, recomp_m, epsilon = 1.0e-5));
|
||||||
prop_assert!(relative_eq!(m, recomp_m, epsilon = 1.0e-5));
|
prop_assert!(relative_eq!(m, recomp_m, epsilon = 1.0e-5));
|
||||||
|
prop_assert!(s.as_slice().windows(2).all(|elts| elts[0] >= elts[1]));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
@ -38,6 +39,7 @@ mod proptest_tests {
|
||||||
prop_assert!(relative_eq!(m, &u * ds * &v_t, epsilon = 1.0e-5));
|
prop_assert!(relative_eq!(m, &u * ds * &v_t, epsilon = 1.0e-5));
|
||||||
prop_assert!(u.is_orthogonal(1.0e-5));
|
prop_assert!(u.is_orthogonal(1.0e-5));
|
||||||
prop_assert!(v_t.is_orthogonal(1.0e-5));
|
prop_assert!(v_t.is_orthogonal(1.0e-5));
|
||||||
|
prop_assert!(s.as_slice().windows(2).all(|elts| elts[0] >= elts[1]));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
@ -50,6 +52,7 @@ mod proptest_tests {
|
||||||
prop_assert!(relative_eq!(m, &u * ds * &v_t, epsilon = 1.0e-5));
|
prop_assert!(relative_eq!(m, &u * ds * &v_t, epsilon = 1.0e-5));
|
||||||
prop_assert!(u.is_orthogonal(1.0e-5));
|
prop_assert!(u.is_orthogonal(1.0e-5));
|
||||||
prop_assert!(v_t.is_orthogonal(1.0e-5));
|
prop_assert!(v_t.is_orthogonal(1.0e-5));
|
||||||
|
prop_assert!(s.as_slice().windows(2).all(|elts| elts[0] >= elts[1]));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
@ -61,6 +64,7 @@ mod proptest_tests {
|
||||||
|
|
||||||
prop_assert!(s.iter().all(|e| *e >= 0.0));
|
prop_assert!(s.iter().all(|e| *e >= 0.0));
|
||||||
prop_assert!(relative_eq!(m, u * ds * v_t, epsilon = 1.0e-5));
|
prop_assert!(relative_eq!(m, u * ds * v_t, epsilon = 1.0e-5));
|
||||||
|
prop_assert!(s.as_slice().windows(2).all(|elts| elts[0] >= elts[1]));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
@ -71,6 +75,7 @@ mod proptest_tests {
|
||||||
|
|
||||||
prop_assert!(s.iter().all(|e| *e >= 0.0));
|
prop_assert!(s.iter().all(|e| *e >= 0.0));
|
||||||
prop_assert!(relative_eq!(m, u * ds * v_t, epsilon = 1.0e-5));
|
prop_assert!(relative_eq!(m, u * ds * v_t, epsilon = 1.0e-5));
|
||||||
|
prop_assert!(s.as_slice().windows(2).all(|elts| elts[0] >= elts[1]));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
@ -83,6 +88,7 @@ mod proptest_tests {
|
||||||
prop_assert!(relative_eq!(m, u * ds * v_t, epsilon = 1.0e-5));
|
prop_assert!(relative_eq!(m, u * ds * v_t, epsilon = 1.0e-5));
|
||||||
prop_assert!(u.is_orthogonal(1.0e-5));
|
prop_assert!(u.is_orthogonal(1.0e-5));
|
||||||
prop_assert!(v_t.is_orthogonal(1.0e-5));
|
prop_assert!(v_t.is_orthogonal(1.0e-5));
|
||||||
|
prop_assert!(s.as_slice().windows(2).all(|elts| elts[0] >= elts[1]));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
@ -95,6 +101,7 @@ mod proptest_tests {
|
||||||
prop_assert!(relative_eq!(m, u * ds * v_t, epsilon = 1.0e-5));
|
prop_assert!(relative_eq!(m, u * ds * v_t, epsilon = 1.0e-5));
|
||||||
prop_assert!(u.is_orthogonal(1.0e-5));
|
prop_assert!(u.is_orthogonal(1.0e-5));
|
||||||
prop_assert!(v_t.is_orthogonal(1.0e-5));
|
prop_assert!(v_t.is_orthogonal(1.0e-5));
|
||||||
|
prop_assert!(s.as_slice().windows(2).all(|elts| elts[0] >= elts[1]));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
@ -107,6 +114,7 @@ mod proptest_tests {
|
||||||
prop_assert!(relative_eq!(m, u * ds * v_t, epsilon = 1.0e-5));
|
prop_assert!(relative_eq!(m, u * ds * v_t, epsilon = 1.0e-5));
|
||||||
prop_assert!(u.is_orthogonal(1.0e-5));
|
prop_assert!(u.is_orthogonal(1.0e-5));
|
||||||
prop_assert!(v_t.is_orthogonal(1.0e-5));
|
prop_assert!(v_t.is_orthogonal(1.0e-5));
|
||||||
|
prop_assert!(s.as_slice().windows(2).all(|elts| elts[0] >= elts[1]));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
@ -187,6 +195,7 @@ fn svd_singular() {
|
||||||
let ds = DMatrix::from_diagonal(&s);
|
let ds = DMatrix::from_diagonal(&s);
|
||||||
|
|
||||||
assert!(s.iter().all(|e| *e >= 0.0));
|
assert!(s.iter().all(|e| *e >= 0.0));
|
||||||
|
assert!(s.as_slice().windows(2).all(|elts| elts[0] >= elts[1]));
|
||||||
assert!(u.is_orthogonal(1.0e-5));
|
assert!(u.is_orthogonal(1.0e-5));
|
||||||
assert!(v_t.is_orthogonal(1.0e-5));
|
assert!(v_t.is_orthogonal(1.0e-5));
|
||||||
assert_relative_eq!(m, &u * ds * &v_t, epsilon = 1.0e-5);
|
assert_relative_eq!(m, &u * ds * &v_t, epsilon = 1.0e-5);
|
||||||
|
@ -229,6 +238,7 @@ fn svd_singular_vertical() {
|
||||||
let ds = DMatrix::from_diagonal(&s);
|
let ds = DMatrix::from_diagonal(&s);
|
||||||
|
|
||||||
assert!(s.iter().all(|e| *e >= 0.0));
|
assert!(s.iter().all(|e| *e >= 0.0));
|
||||||
|
assert!(s.as_slice().windows(2).all(|elts| elts[0] >= elts[1]));
|
||||||
assert_relative_eq!(m, &u * ds * &v_t, epsilon = 1.0e-5);
|
assert_relative_eq!(m, &u * ds * &v_t, epsilon = 1.0e-5);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -267,6 +277,7 @@ fn svd_singular_horizontal() {
|
||||||
let ds = DMatrix::from_diagonal(&s);
|
let ds = DMatrix::from_diagonal(&s);
|
||||||
|
|
||||||
assert!(s.iter().all(|e| *e >= 0.0));
|
assert!(s.iter().all(|e| *e >= 0.0));
|
||||||
|
assert!(s.as_slice().windows(2).all(|elts| elts[0] >= elts[1]));
|
||||||
assert_relative_eq!(m, &u * ds * &v_t, epsilon = 1.0e-5);
|
assert_relative_eq!(m, &u * ds * &v_t, epsilon = 1.0e-5);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -350,6 +361,26 @@ fn svd_fail() {
|
||||||
assert_relative_eq!(m, recomp, epsilon = 1.0e-5);
|
assert_relative_eq!(m, recomp, epsilon = 1.0e-5);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
#[rustfmt::skip]
|
||||||
|
fn svd3_fail() {
|
||||||
|
let m = nalgebra::matrix![
|
||||||
|
0.0, 1.0, 0.0;
|
||||||
|
0.0, 1.7320508075688772, 0.0;
|
||||||
|
0.0, 0.0, 0.0
|
||||||
|
];
|
||||||
|
|
||||||
|
// Check unordered ...
|
||||||
|
let svd = m.svd_unordered(true, true);
|
||||||
|
let recomp = svd.recompose().unwrap();
|
||||||
|
assert_relative_eq!(m, recomp, epsilon = 1.0e-5);
|
||||||
|
|
||||||
|
// ... and ordered SVD.
|
||||||
|
let svd = m.svd(true, true);
|
||||||
|
let recomp = svd.recompose().unwrap();
|
||||||
|
assert_relative_eq!(m, recomp, epsilon = 1.0e-5);
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn svd_err() {
|
fn svd_err() {
|
||||||
let m = DMatrix::from_element(10, 10, 0.0);
|
let m = DMatrix::from_element(10, 10, 0.0);
|
||||||
|
|
Loading…
Reference in New Issue