Fix SVD instability bug

This commit is contained in:
YuhanLiin 2022-03-09 02:13:12 -05:00
parent 757b99e843
commit 325618ba22
2 changed files with 31 additions and 15 deletions

View File

@ -2,7 +2,6 @@
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::any::TypeId; use std::any::TypeId;
use approx::AbsDiffEq;
use num::{One, Zero}; use num::{One, Zero};
use crate::allocator::Allocator; use crate::allocator::Allocator;
@ -94,14 +93,7 @@ where
/// The singular values are not guaranteed to be sorted in any particular order. /// The singular values are not guaranteed to be sorted in any particular order.
/// If a descending order is required, consider using `new` instead. /// If a descending order is required, consider using `new` instead.
pub fn new_unordered(matrix: OMatrix<T, R, C>, compute_u: bool, compute_v: bool) -> Self { pub fn new_unordered(matrix: OMatrix<T, R, C>, compute_u: bool, compute_v: bool) -> Self {
Self::try_new_unordered( Self::try_new_unordered(matrix, compute_u, compute_v, crate::convert(1e-15), 0).unwrap()
matrix,
compute_u,
compute_v,
T::RealField::default_epsilon(),
0,
)
.unwrap()
} }
/// Attempts to compute the Singular Value Decomposition of `matrix` using implicit shift. /// Attempts to compute the Singular Value Decomposition of `matrix` using implicit shift.
@ -888,13 +880,13 @@ fn compute_2x2_uptrig_svd<T: RealField>(
v_t = Some(csv.clone()); v_t = Some(csv.clone());
} }
if compute_u {
let cu = (m11.scale(csv.c()) + m12 * csv.s()) / v1.clone(); let cu = (m11.scale(csv.c()) + m12 * csv.s()) / v1.clone();
let su = (m22 * csv.s()) / v1.clone(); let su = (m22 * csv.s()) / v1.clone();
let (csu, sgn_u) = GivensRotation::new(cu, su); let (csu, sgn_u) = GivensRotation::new(cu, su);
v1 *= sgn_u.clone(); v1 *= sgn_u.clone();
v2 *= sgn_u; v2 *= sgn_u;
if compute_u {
u = Some(csu); u = Some(csu);
} }
} }

View File

@ -460,3 +460,27 @@ fn svd_sorted() {
epsilon = 1.0e-5 epsilon = 1.0e-5
); );
} }
#[test]
// Exercises bug reported in issue #983 of nalgebra
fn svd_consistent() {
let m = nalgebra::dmatrix![
10.74785316637712f64, -5.994983325167452, -6.064492921857296;
-4.149751381521569, 20.654504205822462, -4.470436210703133;
-22.772715014220207, -1.4554372570788008, 18.108113992170573
]
.transpose();
let svd1 = m.clone().svd(true, true);
let svd2 = m.clone().svd(false, true);
let svd3 = m.clone().svd(true, false);
let svd4 = m.svd(false, false);
assert_relative_eq!(svd1.singular_values, svd2.singular_values, epsilon = 1e-5);
assert_relative_eq!(svd1.singular_values, svd3.singular_values, epsilon = 1e-5);
assert_relative_eq!(svd1.singular_values, svd4.singular_values, epsilon = 1e-5);
assert_relative_eq!(
svd1.singular_values,
nalgebra::dvector![3.16188022e+01, 2.23811978e+01, 0.],
epsilon = 1e-5
);
}