From 8b1aa2078c2e382a2f7e54496e95cfcfd17a9691 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Joa=CC=83o=20Costa?= Date: Mon, 8 Oct 2018 22:13:56 +0100 Subject: [PATCH] Change the SVD methods to return a Result instead of panicking --- src/linalg/svd.rs | 110 ++++++++++++++++++++++++-------------------- tests/linalg/svd.rs | 34 ++++++++------ 2 files changed, 80 insertions(+), 64 deletions(-) diff --git a/src/linalg/svd.rs b/src/linalg/svd.rs index 43e2946e..38caf255 100644 --- a/src/linalg/svd.rs +++ b/src/linalg/svd.rs @@ -485,89 +485,97 @@ where /// Rebuild the original matrix. /// - /// This is useful if some of the singular values have been manually modified. Panics if the - /// right- and left- singular vectors have not been computed at construction-time. - pub fn recompose(self) -> MatrixMN { - let mut u = self.u.expect("SVD recomposition: U has not been computed."); - let v_t = self.v_t - .expect("SVD recomposition: V^t has not been computed."); + /// This is useful if some of the singular values have been manually modified. + /// Returns `Err` if the right- and left- singular vectors have not been + /// computed at construction-time. + pub fn recompose(self) -> Result, &'static str> { + match (self.u, self.v_t) { + (Some(_u), Some(_v_t)) => { + let mut u = _u; + let v_t = _v_t; - for i in 0..self.singular_values.len() { - let val = self.singular_values[i]; - u.column_mut(i).mul_assign(val); + for i in 0..self.singular_values.len() { + let val = self.singular_values[i]; + u.column_mut(i).mul_assign(val); + } + Ok(u * v_t) + } + (None, None) => Err("SVD recomposition: U and V^t have not been computed."), + (None, _) => Err("SVD recomposition: U has not been computed."), + (_, None) => Err("SVD recomposition: V^t has not been computed.") } - - u * v_t } /// Computes the pseudo-inverse of the decomposed matrix. /// /// Any singular value smaller than `eps` is assumed to be zero. - /// Panics if the right- and left- singular vectors have not been computed at - /// construction-time. - pub fn pseudo_inverse(mut self, eps: N) -> MatrixMN + /// Returns `Err` if the right- and left- singular vectors have not + /// been computed at construction-time. + pub fn pseudo_inverse(mut self, eps: N) -> Result, &'static str> where DefaultAllocator: Allocator, { - assert!( - eps >= N::zero(), - "SVD pseudo inverse: the epsilon must be non-negative." - ); - for i in 0..self.singular_values.len() { - let val = self.singular_values[i]; - - if val > eps { - self.singular_values[i] = N::one() / val; - } else { - self.singular_values[i] = N::zero(); - } + if eps < N::zero() { + Err("SVD pseudo inverse: the epsilon must be non-negative.") } + else { + for i in 0..self.singular_values.len() { + let val = self.singular_values[i]; - self.recompose().transpose() + if val > eps { + self.singular_values[i] = N::one() / val; + } else { + self.singular_values[i] = N::zero(); + } + } + + self.recompose().map(|m| m.transpose()) + } } /// Solves the system `self * x = b` where `self` is the decomposed matrix and `x` the unknown. /// /// Any singular value smaller than `eps` is assumed to be zero. - /// Returns `None` if the singular vectors `U` and `V` have not been computed. + /// Returns `Err` if the singular vectors `U` and `V` have not been computed. // FIXME: make this more generic wrt the storage types and the dimensions for `b`. pub fn solve( &self, b: &Matrix, eps: N, - ) -> MatrixMN + ) -> Result, &'static str> where S2: Storage, DefaultAllocator: Allocator + Allocator, C2>, ShapeConstraint: SameNumberOfRows, { - assert!( - eps >= N::zero(), - "SVD solve: the epsilon must be non-negative." - ); - let u = self.u - .as_ref() - .expect("SVD solve: U has not been computed."); - let v_t = self.v_t - .as_ref() - .expect("SVD solve: V^t has not been computed."); + if eps < N::zero() { + Err("SVD solve: the epsilon must be non-negative.") + } + else { + match (&self.u, &self.v_t) { + (Some(u), Some(v_t)) => { + let mut ut_b = u.tr_mul(b); - let mut ut_b = u.tr_mul(b); + for j in 0..ut_b.ncols() { + let mut col = ut_b.column_mut(j); - for j in 0..ut_b.ncols() { - let mut col = ut_b.column_mut(j); + for i in 0..self.singular_values.len() { + let val = self.singular_values[i]; + if val > eps { + col[i] /= val; + } else { + col[i] = N::zero(); + } + } + } - for i in 0..self.singular_values.len() { - let val = self.singular_values[i]; - if val > eps { - col[i] /= val; - } else { - col[i] = N::zero(); + Ok(v_t.tr_mul(&ut_b)) } + (None, None) => Err("SVD solve: U and V^t have not been computed."), + (None, _) => Err("SVD solve: U has not been computed."), + (_, None) => Err("SVD solve: V^t has not been computed.") } } - - v_t.tr_mul(&ut_b) } } @@ -623,7 +631,7 @@ where /// Computes the pseudo-inverse of this matrix. /// /// All singular values below `eps` are considered equal to 0. - pub fn pseudo_inverse(self, eps: N) -> MatrixMN + pub fn pseudo_inverse(self, eps: N) -> Result, &'static str> where DefaultAllocator: Allocator, { diff --git a/tests/linalg/svd.rs b/tests/linalg/svd.rs index 5b09c7c8..06b7f135 100644 --- a/tests/linalg/svd.rs +++ b/tests/linalg/svd.rs @@ -9,7 +9,7 @@ mod quickcheck_tests { fn svd(m: DMatrix) -> bool { if m.len() > 0 { let svd = m.clone().svd(true, true); - let recomp_m = svd.clone().recompose(); + let recomp_m = svd.clone().recompose().unwrap(); let (u, s, v_t) = (svd.u.unwrap(), svd.singular_values, svd.v_t.unwrap()); let ds = DMatrix::from_diagonal(&s); @@ -90,7 +90,7 @@ mod quickcheck_tests { fn svd_pseudo_inverse(m: DMatrix) -> bool { if m.len() > 0 { let svd = m.clone().svd(true, true); - let pinv = svd.pseudo_inverse(1.0e-10); + let pinv = svd.pseudo_inverse(1.0e-10).unwrap(); if m.nrows() > m.ncols() { println!("{}", &pinv * &m); @@ -117,10 +117,10 @@ mod quickcheck_tests { let b1 = DVector::new_random(n); let b2 = DMatrix::new_random(n, nb); - let sol1 = svd.solve(&b1, 1.0e-7); - let sol2 = svd.solve(&b2, 1.0e-7); + let sol1 = svd.solve(&b1, 1.0e-7).unwrap(); + let sol2 = svd.solve(&b2, 1.0e-7).unwrap(); - let recomp = svd.recompose(); + let recomp = svd.recompose().unwrap(); if !relative_eq!(m, recomp, epsilon = 1.0e-6) { println!("{}{}", m, recomp); } @@ -262,22 +262,22 @@ fn svd_singular_horizontal() { fn svd_zeros() { let m = DMatrix::from_element(10, 10, 0.0); let svd = m.clone().svd(true, true); - assert_eq!(m, svd.recompose()); + assert_eq!(Ok(m), svd.recompose()); } #[test] fn svd_identity() { let m = DMatrix::::identity(10, 10); let svd = m.clone().svd(true, true); - assert_eq!(m, svd.recompose()); + assert_eq!(Ok(m), svd.recompose()); let m = DMatrix::::identity(10, 15); let svd = m.clone().svd(true, true); - assert_eq!(m, svd.recompose()); + assert_eq!(Ok(m), svd.recompose()); let m = DMatrix::::identity(15, 10); let svd = m.clone().svd(true, true); - assert_eq!(m, svd.recompose()); + assert_eq!(Ok(m), svd.recompose()); } #[test] @@ -294,7 +294,7 @@ fn svd_with_delimited_subproblem() { m[(8,8)] = 16.0; m[(3,9)] = 17.0; m[(9,9)] = 18.0; let svd = m.clone().svd(true, true); - assert!(relative_eq!(m, svd.recompose(), epsilon = 1.0e-7)); + assert!(relative_eq!(m, svd.recompose().unwrap(), epsilon = 1.0e-7)); // Rectangular versions. let mut m = DMatrix::::from_element(15, 10, 0.0); @@ -309,10 +309,10 @@ fn svd_with_delimited_subproblem() { m[(8,8)] = 16.0; m[(3,9)] = 17.0; m[(9,9)] = 18.0; let svd = m.clone().svd(true, true); - assert!(relative_eq!(m, svd.recompose(), epsilon = 1.0e-7)); + assert!(relative_eq!(m, svd.recompose().unwrap(), epsilon = 1.0e-7)); let svd = m.transpose().svd(true, true); - assert!(relative_eq!(m.transpose(), svd.recompose(), epsilon = 1.0e-7)); + assert!(relative_eq!(m.transpose(), svd.recompose().unwrap(), epsilon = 1.0e-7)); } #[test] @@ -328,7 +328,15 @@ fn svd_fail() { println!("Singular values: {}", svd.singular_values); println!("u: {:.5}", svd.u.unwrap()); println!("v: {:.5}", svd.v_t.unwrap()); - let recomp = svd.recompose(); + let recomp = svd.recompose().unwrap(); println!("{:.5}{:.5}", m, recomp); assert!(relative_eq!(m, recomp, epsilon = 1.0e-5)); } + +#[test] +fn svd_err() { + let m = DMatrix::from_element(10, 10, 0.0); + let svd = m.clone().svd(false, false); + assert_eq!(Err("SVD recomposition: U and V^t have not been computed."), svd.clone().recompose()); + assert_eq!(Err("SVD pseudo inverse: the epsilon must be non-negative."), svd.clone().pseudo_inverse(-1.0)); +}