From 2f0d95bdbb3ceaa34428e0f6d9d9512f75f92afa Mon Sep 17 00:00:00 2001 From: sebcrozet Date: Tue, 19 Mar 2019 12:00:10 +0100 Subject: [PATCH] Fix most tests. --- src/base/matrix.rs | 8 +++++ src/geometry/reflection.rs | 39 +++++++++++++++++++++ src/linalg/bidiagonal.rs | 22 +++++++++--- src/linalg/hessenberg.rs | 6 ++-- src/linalg/householder.rs | 14 ++++---- src/linalg/qr.rs | 13 +++---- src/linalg/svd.rs | 6 ++-- src/linalg/symmetric_tridiagonal.rs | 22 ++++++------ tests/linalg/bidiagonal.rs | 1 - tests/linalg/svd.rs | 53 +++++++++++++++-------------- 10 files changed, 124 insertions(+), 60 deletions(-) diff --git a/src/base/matrix.rs b/src/base/matrix.rs index e7cf0026..1e402d29 100644 --- a/src/base/matrix.rs +++ b/src/base/matrix.rs @@ -770,6 +770,14 @@ impl> Matrix { } } + // FIXME: rename `apply` to `apply_mut` and `apply_into` to `apply`? + /// Returns `self` with each of its components replaced by the result of a closure `f` applied on it. + #[inline] + pub fn apply_into N>(mut self, mut f: F) -> Self{ + self.apply(f); + self + } + /// Replaces each component of `self` by the result of a closure `f` applied on it. #[inline] pub fn apply N>(&mut self, mut f: F) { diff --git a/src/geometry/reflection.rs b/src/geometry/reflection.rs index 237054b6..b3c77bd9 100644 --- a/src/geometry/reflection.rs +++ b/src/geometry/reflection.rs @@ -61,6 +61,23 @@ impl> Reflection { } } + // FIXME: naming convention: reflect_to, reflect_assign ? + /// Applies the reflection to the columns of `rhs`. + pub fn reflect_with_sign(&self, rhs: &mut Matrix, sign: N) + where + S2: StorageMut, + ShapeConstraint: SameNumberOfRows, + { + for i in 0..rhs.ncols() { + // NOTE: we borrow the column twice here. First it is borrowed immutably for the + // dot product, and then mutably. Somehow, this allows significantly + // better optimizations of the dot product from the compiler. + let m_two = sign.scale(::convert(-2.0f64)); + let factor = (self.axis.cdot(&rhs.column(i)) - self.bias) * m_two; + rhs.column_mut(i).axpy(factor, &self.axis, sign); + } + } + /// Applies the reflection to the rows of `lhs`. pub fn reflect_rows( &self, @@ -81,4 +98,26 @@ impl> Reflection { let m_two: N = ::convert(-2.0f64); lhs.ger(m_two, &work, &self.axis.conjugate(), N::one()); } + + /// Applies the reflection to the rows of `lhs`. + pub fn reflect_rows_with_sign( + &self, + lhs: &mut Matrix, + work: &mut Vector, + sign: N, + ) where + S2: StorageMut, + S3: StorageMut, + ShapeConstraint: DimEq + AreMultipliable, + DefaultAllocator: Allocator + { + lhs.mul_to(&self.axis, work); + + if !self.bias.is_zero() { + work.add_scalar_mut(-self.bias); + } + + let m_two = sign.scale(::convert(-2.0f64)); + lhs.ger(m_two, &work, &self.axis.conjugate(), sign); + } } diff --git a/src/linalg/bidiagonal.rs b/src/linalg/bidiagonal.rs index e487758c..51087e96 100644 --- a/src/linalg/bidiagonal.rs +++ b/src/linalg/bidiagonal.rs @@ -200,11 +200,11 @@ where let d = nrows.min(ncols); let mut res = MatrixN::identity_generic(d, d); - res.set_diagonal(&self.diagonal); + res.set_diagonal(&self.diagonal.map(|e| N::from_real(e.modulus()))); let start = self.axis_shift(); res.slice_mut(start, (d.value() - 1, d.value() - 1)) - .set_diagonal(&self.off_diagonal); + .set_diagonal(&self.off_diagonal.map(|e| N::from_real(e.modulus()))); res } @@ -225,7 +225,14 @@ where let refl = Reflection::new(Unit::new_unchecked(axis), N::zero()); let mut res_rows = res.slice_range_mut(i + shift.., i..); - refl.reflect(&mut res_rows); + + let sign = if self.upper_diagonal { + self.diagonal[i].signum() + } else { + self.off_diagonal[i].signum() + }; + + refl.reflect_with_sign(&mut res_rows, sign); } res @@ -251,7 +258,14 @@ where let refl = Reflection::new(Unit::new_unchecked(axis_packed), N::zero()); let mut res_rows = res.slice_range_mut(i.., i + shift..); - refl.reflect_rows(&mut res_rows, &mut work.rows_range_mut(i..)); + + let sign = if self.upper_diagonal { + self.off_diagonal[i].signum() + } else { + self.diagonal[i].signum() + }; + + refl.reflect_rows_with_sign(&mut res_rows, &mut work.rows_range_mut(i..), sign); } res diff --git a/src/linalg/hessenberg.rs b/src/linalg/hessenberg.rs index 25ab445b..a73ee5b5 100644 --- a/src/linalg/hessenberg.rs +++ b/src/linalg/hessenberg.rs @@ -107,7 +107,7 @@ where DefaultAllocator: Allocator + Allocator + Allocator + Allocator + Allocator MatrixN { - householder::assemble_q(&self.hess) + householder::assemble_q(&self.hess, self.subdiag.as_slice()) } #[doc(hidden)] diff --git a/src/linalg/householder.rs b/src/linalg/householder.rs index dc97b9b3..d04da188 100644 --- a/src/linalg/householder.rs +++ b/src/linalg/householder.rs @@ -60,10 +60,11 @@ pub fn clear_column_unchecked( if not_zero { let refl = Reflection::new(Unit::new_unchecked(axis), N::zero()); + let sign = reflection_norm.signum(); if let Some(mut work) = bilateral { - refl.reflect_rows(&mut right, &mut work); + refl.reflect_rows_with_sign(&mut right, &mut work, sign); } - refl.reflect(&mut right.rows_range_mut(icol + shift..)); + refl.reflect_with_sign(&mut right.rows_range_mut(icol + shift..), sign.conjugate()); } } @@ -90,9 +91,10 @@ pub fn clear_row_unchecked( if not_zero { let refl = Reflection::new(Unit::new_unchecked(axis), N::zero()); - refl.reflect_rows( + refl.reflect_rows_with_sign( &mut bottom.columns_range_mut(irow + shift..), &mut work.rows_range_mut(irow + 1..), + reflection_norm.signum().conjugate(), ); top.columns_range_mut(irow + shift..) .tr_copy_from(&refl.axis()); @@ -101,11 +103,11 @@ pub fn clear_row_unchecked( } } -/// Computes the orthogonal transformation described by the elementary reflector axices stored on +/// Computes the orthogonal transformation described by the elementary reflector axii stored on /// the lower-diagonal element of the given matrix. /// matrices. #[doc(hidden)] -pub fn assemble_q(m: &MatrixN) -> MatrixN +pub fn assemble_q(m: &MatrixN, signs: &[N]) -> MatrixN where DefaultAllocator: Allocator { assert!(m.is_square()); let dim = m.data.shape().0; @@ -119,7 +121,7 @@ where DefaultAllocator: Allocator { let refl = Reflection::new(Unit::new_unchecked(axis), N::zero()); let mut res_rows = res.slice_range_mut(i + 1.., i..); - refl.reflect(&mut res_rows); + refl.reflect_with_sign(&mut res_rows, signs[i].signum()); } res diff --git a/src/linalg/qr.rs b/src/linalg/qr.rs index 023a5042..e48c057c 100644 --- a/src/linalg/qr.rs +++ b/src/linalg/qr.rs @@ -1,5 +1,6 @@ #[cfg(feature = "serde-serialize")] use serde::{Deserialize, Serialize}; +use num::Zero; use alga::general::Complex; use allocator::{Allocator, Reallocator}; @@ -83,7 +84,7 @@ where DefaultAllocator: Allocator + Allocator + Allocator + Allocator + Allocator + Allocator + Allocator + Allocator + Allocator + Allocator let coeff; unsafe { - let diag = *self.diag.vget_unchecked(i); + let diag = self.diag.vget_unchecked(i).modulus(); if diag.is_zero() { return false; } - coeff = *b.vget_unchecked(i) / diag; + coeff = b.vget_unchecked(i).unscale(diag); *b.vget_unchecked_mut(i) = coeff; } diff --git a/src/linalg/svd.rs b/src/linalg/svd.rs index bcc93333..852878f7 100644 --- a/src/linalg/svd.rs +++ b/src/linalg/svd.rs @@ -543,7 +543,7 @@ where } } - self.recompose().map(|m| m.transpose()) + self.recompose().map(|m| m.conjugate_transpose()) } } @@ -568,7 +568,7 @@ where else { match (&self.u, &self.v_t) { (Some(u), Some(v_t)) => { - let mut ut_b = u.tr_mul(b); + let mut ut_b = u.conjugate().tr_mul(b); for j in 0..ut_b.ncols() { let mut col = ut_b.column_mut(j); @@ -583,7 +583,7 @@ where } } - Ok(v_t.tr_mul(&ut_b)) + Ok(v_t.conjugate().tr_mul(&ut_b)) } (None, None) => Err("SVD solve: U and V^t have not been computed."), (None, _) => Err("SVD solve: U has not been computed."), diff --git a/src/linalg/symmetric_tridiagonal.rs b/src/linalg/symmetric_tridiagonal.rs index cd4f7c3d..8fd36003 100644 --- a/src/linalg/symmetric_tridiagonal.rs +++ b/src/linalg/symmetric_tridiagonal.rs @@ -79,6 +79,7 @@ where DefaultAllocator: Allocator + Allocator> p.cgemv_symm(::convert(2.0), &m, &axis, N::zero()); let dot = axis.cdot(&p); + // p.axpy(-dot, &axis.conjugate(), N::one()); m.ger_symm(-N::one(), &p, &axis.conjugate(), N::one()); m.ger_symm(-N::one(), &axis, &p.conjugate(), N::one()); @@ -106,32 +107,30 @@ where DefaultAllocator: Allocator + Allocator> let diag = self.diagonal(); let q = self.q(); - (q, diag, self.off_diagonal) + (q, diag, self.off_diagonal.apply_into(|e| N::from_real(e.modulus()))) } /// Retrieve the diagonal, and off diagonal elements of this decomposition. - pub fn unpack_tridiagonal(self) -> (VectorN, VectorN>) + pub fn unpack_tridiagonal(mut self) -> (VectorN, VectorN>) where DefaultAllocator: Allocator { let diag = self.diagonal(); - (diag, self.off_diagonal) + (diag, self.off_diagonal.apply_into(|e| N::from_real(e.modulus()))) } /// The diagonal components of this decomposition. pub fn diagonal(&self) -> VectorN - where DefaultAllocator: Allocator { - self.tri.diagonal() - } + where DefaultAllocator: Allocator { self.tri.diagonal() } /// The off-diagonal components of this decomposition. - pub fn off_diagonal(&self) -> &VectorN> + pub fn off_diagonal(&self) -> VectorN> where DefaultAllocator: Allocator { - &self.off_diagonal + self.off_diagonal.map(|e| N::from_real(e.modulus())) } /// Computes the orthogonal matrix `Q` of this decomposition. pub fn q(&self) -> MatrixN { - householder::assemble_q(&self.tri) + householder::assemble_q(&self.tri, self.off_diagonal.as_slice()) } /// Recomputes the original symmetric matrix. @@ -141,8 +140,9 @@ where DefaultAllocator: Allocator + Allocator> self.tri.fill_upper_triangle(N::zero(), 2); for i in 0..self.off_diagonal.len() { - self.tri[(i + 1, i)] = self.off_diagonal[i]; - self.tri[(i, i + 1)] = self.off_diagonal[i].conjugate(); + let val = N::from_real(self.off_diagonal[i].modulus()); + self.tri[(i + 1, i)] = val; + self.tri[(i, i + 1)] = val; } &q * self.tri * q.conjugate_transpose() diff --git a/tests/linalg/bidiagonal.rs b/tests/linalg/bidiagonal.rs index dbb0c4fb..28b1e3a9 100644 --- a/tests/linalg/bidiagonal.rs +++ b/tests/linalg/bidiagonal.rs @@ -41,7 +41,6 @@ quickcheck! { relative_eq!(m, &u * d * &v_t, epsilon = 1.0e-7) } - fn bidiagonal_static_square(m: Matrix4>) -> bool { let m = m.map(|e| e.0); let bidiagonal = m.bidiagonalize(); diff --git a/tests/linalg/svd.rs b/tests/linalg/svd.rs index 91f2002d..860c7ce2 100644 --- a/tests/linalg/svd.rs +++ b/tests/linalg/svd.rs @@ -5,14 +5,15 @@ use na::{DMatrix, Matrix6}; mod quickcheck_tests { use na::{ DMatrix, DVector, Matrix2, Matrix2x5, Matrix3, Matrix3x5, Matrix4, Matrix5x2, Matrix5x3, + Complex }; use std::cmp; use core::helper::{RandScalar, RandComplex}; quickcheck! { - /* - fn svd(m: DMatrix) -> bool { + fn svd(m: DMatrix>) -> bool { + let m = m.map(|e| e.0); if m.len() > 0 { let svd = m.clone().svd(true, true); let recomp_m = svd.clone().recompose().unwrap(); @@ -21,7 +22,7 @@ mod quickcheck_tests { println!("{}{}", &m, &u * &ds * &v_t); - s.iter().all(|e| *e >= 0.0) && + s.iter().all(|e| e.real() >= 0.0) && relative_eq!(&u * ds * &v_t, recomp_m, epsilon = 1.0e-5) && relative_eq!(m, recomp_m, epsilon = 1.0e-5) } @@ -30,60 +31,62 @@ mod quickcheck_tests { } } - fn svd_static_5_3(m: Matrix5x3) -> bool { + fn svd_static_5_3(m: Matrix5x3>) -> bool { + let m = m.map(|e| e.0); let svd = m.svd(true, true); let (u, s, v_t) = (svd.u.unwrap(), svd.singular_values, svd.v_t.unwrap()); let ds = Matrix3::from_diagonal(&s); - s.iter().all(|e| *e >= 0.0) && + s.iter().all(|e| e.real() >= 0.0) && relative_eq!(m, &u * ds * &v_t, epsilon = 1.0e-5) && u.is_orthogonal(1.0e-5) && v_t.is_orthogonal(1.0e-5) } - fn svd_static_5_2(m: Matrix5x2) -> bool { + fn svd_static_5_2(m: Matrix5x2>) -> bool { + let m = m.map(|e| e.0); let svd = m.svd(true, true); let (u, s, v_t) = (svd.u.unwrap(), svd.singular_values, svd.v_t.unwrap()); let ds = Matrix2::from_diagonal(&s); - s.iter().all(|e| *e >= 0.0) && + s.iter().all(|e| e.real() >= 0.0) && relative_eq!(m, &u * ds * &v_t, epsilon = 1.0e-5) && u.is_orthogonal(1.0e-5) && v_t.is_orthogonal(1.0e-5) } - fn svd_static_3_5(m: Matrix3x5) -> bool { + fn svd_static_3_5(m: Matrix3x5>) -> bool { + let m = m.map(|e| e.0); let svd = m.svd(true, true); let (u, s, v_t) = (svd.u.unwrap(), svd.singular_values, svd.v_t.unwrap()); let ds = Matrix3::from_diagonal(&s); - s.iter().all(|e| *e >= 0.0) && + s.iter().all(|e| e.real() >= 0.0) && relative_eq!(m, u * ds * v_t, epsilon = 1.0e-5) } - fn svd_static_2_5(m: Matrix2x5) -> bool { + fn svd_static_2_5(m: Matrix2x5>) -> bool { + let m = m.map(|e| e.0); let svd = m.svd(true, true); let (u, s, v_t) = (svd.u.unwrap(), svd.singular_values, svd.v_t.unwrap()); let ds = Matrix2::from_diagonal(&s); - s.iter().all(|e| *e >= 0.0) && + s.iter().all(|e| e.real() >= 0.0) && relative_eq!(m, u * ds * v_t, epsilon = 1.0e-5) } - - fn svd_static_square(m: Matrix4) -> bool { + fn svd_static_square(m: Matrix4>) -> bool { + let m = m.map(|e| e.0); let svd = m.svd(true, true); let (u, s, v_t) = (svd.u.unwrap(), svd.singular_values, svd.v_t.unwrap()); let ds = Matrix4::from_diagonal(&s); - s.iter().all(|e| *e >= 0.0) && + s.iter().all(|e| e.real() >= 0.0) && relative_eq!(m, u * ds * v_t, epsilon = 1.0e-5) && u.is_orthogonal(1.0e-5) && v_t.is_orthogonal(1.0e-5) } - */ - fn svd_static_square_2x2(m: Matrix2>) -> bool { let m = m.map(|e| e.0); @@ -102,8 +105,9 @@ mod quickcheck_tests { v_t.is_orthogonal(1.0e-5) } -/* - fn svd_pseudo_inverse(m: DMatrix) -> bool { + fn svd_pseudo_inverse(m: DMatrix>) -> bool { + let m = m.map(|e| e.0); + if m.len() > 0 { let svd = m.clone().svd(true, true); let pinv = svd.pseudo_inverse(1.0e-10).unwrap(); @@ -125,13 +129,13 @@ mod quickcheck_tests { fn svd_solve(n: usize, nb: usize) -> bool { let n = cmp::max(1, cmp::min(n, 10)); let nb = cmp::min(nb, 10); - let m = DMatrix::::new_random(n, n); + let m = DMatrix::>::new_random(n, n).map(|e| e.0); let svd = m.clone().svd(true, true); if svd.rank(1.0e-7) == n { - let b1 = DVector::new_random(n); - let b2 = DMatrix::new_random(n, nb); + let b1 = DVector::>::new_random(n).map(|e| e.0); + let b2 = DMatrix::>::new_random(n, nb).map(|e| e.0); let sol1 = svd.solve(&b1, 1.0e-7).unwrap(); let sol2 = svd.solve(&b2, 1.0e-7).unwrap(); @@ -153,11 +157,10 @@ mod quickcheck_tests { true } - */ } } -/* + // Test proposed on the issue #176 of rulinalg. #[test] fn svd_singular() { @@ -356,6 +359,4 @@ fn svd_err() { let svd = m.clone().svd(false, false); assert_eq!(Err("SVD recomposition: U and V^t have not been computed."), svd.clone().recompose()); assert_eq!(Err("SVD pseudo inverse: the epsilon must be non-negative."), svd.clone().pseudo_inverse(-1.0)); -} - -*/ \ No newline at end of file +} \ No newline at end of file