From ac94fbe831ab4756d691f1e1e749b2b27a67390d Mon Sep 17 00:00:00 2001 From: metric-space Date: Sun, 26 Dec 2021 21:01:05 -0500 Subject: [PATCH] Add polar decomposition method to main matrix decomposition interface Add one more test for decomposition of polar decomposition of rectangular matrix --- src/linalg/decomposition.rs | 37 +++++++++++++++++++++++++++++++++++-- src/linalg/svd.rs | 29 +++++++++++------------------ tests/linalg/svd.rs | 14 ++++++++++++-- 3 files changed, 58 insertions(+), 22 deletions(-) diff --git a/src/linalg/decomposition.rs b/src/linalg/decomposition.rs index 91ad03d9..d75cae9c 100644 --- a/src/linalg/decomposition.rs +++ b/src/linalg/decomposition.rs @@ -1,8 +1,8 @@ use crate::storage::Storage; use crate::{ Allocator, Bidiagonal, Cholesky, ColPivQR, ComplexField, DefaultAllocator, Dim, DimDiff, - DimMin, DimMinimum, DimSub, FullPivLU, Hessenberg, Matrix, RealField, Schur, SymmetricEigen, - SymmetricTridiagonal, LU, QR, SVD, U1, UDU, + DimMin, DimMinimum, DimSub, FullPivLU, Hessenberg, Matrix, OMatrix, RealField, Schur, + SymmetricEigen, SymmetricTridiagonal, LU, QR, SVD, U1, UDU, }; /// # Rectangular matrix decomposition @@ -17,6 +17,7 @@ use crate::{ /// | LU with partial pivoting | `P⁻¹ * L * U` | `L` is lower-triangular with a diagonal filled with `1` and `U` is upper-triangular. `P` is a permutation matrix. | /// | LU with full pivoting | `P⁻¹ * L * U * Q⁻¹` | `L` is lower-triangular with a diagonal filled with `1` and `U` is upper-triangular. `P` and `Q` are permutation matrices. | /// | SVD | `U * Σ * Vᵀ` | `U` and `V` are two orthogonal matrices and `Σ` is a diagonal matrix containing the singular values. | +/// | Polar (Left Polar) | `P' * U` | `U` is semi-unitary/unitary and `P'` is a positive semi-definite Hermitian Matrix impl> Matrix { /// Computes the bidiagonalization using householder reflections. pub fn bidiagonalize(self) -> Bidiagonal @@ -186,6 +187,38 @@ impl> Matrix { { SVD::try_new_unordered(self.into_owned(), compute_u, compute_v, eps, max_niter) } + + /// Attempts to compute the Polar Decomposition of a `matrix + /// + /// # Arguments + /// + /// * `eps` − tolerance used to determine when a value converged to 0. + /// * `max_niter` − maximum total number of iterations performed by the algorithm + pub fn polar( + self, + eps: T::RealField, + max_niter: usize, + ) -> Option<(OMatrix, OMatrix)> + where + R: DimMin, + DimMinimum: DimSub, // for Bidiagonal. + DefaultAllocator: Allocator + + Allocator, R> + + Allocator> + + Allocator + + Allocator, DimMinimum> + + Allocator + + Allocator + + Allocator, U1>> + + Allocator, C> + + Allocator> + + Allocator> + + Allocator> + + Allocator, U1>>, + { + SVD::try_new_unordered(self.into_owned(), true, true, eps, max_niter) + .and_then(|svd| svd.to_polar()) + } } /// # Square matrix decomposition diff --git a/src/linalg/svd.rs b/src/linalg/svd.rs index dabcb491..0e5d7f6c 100644 --- a/src/linalg/svd.rs +++ b/src/linalg/svd.rs @@ -641,33 +641,26 @@ where } } } -} -impl, C: Dim> SVD -where - DefaultAllocator: Allocator, C> - + Allocator> - + Allocator>, -{ - /// converts SVD results to a polar form - - pub fn to_polar(&self) -> Result<(OMatrix, OMatrix), &'static str> - where DefaultAllocator: Allocator //result + /// converts SVD results to Polar decomposition form of the original Matrix + /// A = P'U + /// The polar decomposition used here is Left Polar Decomposition (or Reverse Polar Decomposition) + /// Returns None if the SVD hasn't been calculated + pub fn to_polar(&self) -> Option<(OMatrix, OMatrix)> + where + DefaultAllocator: Allocator //result + Allocator, R> // adjoint + Allocator> // mapped vals - + Allocator // square matrix & result - + Allocator, DimMinimum> // ? - , + + Allocator // result + + Allocator, DimMinimum>, // square matrix { match (&self.u, &self.v_t) { - (Some(u), Some(v_t)) => Ok(( + (Some(u), Some(v_t)) => Some(( u * OMatrix::from_diagonal(&self.singular_values.map(|e| T::from_real(e))) * u.adjoint(), u * v_t, )), - (None, None) => Err("SVD solve: U and V^t have not been computed."), - (None, _) => Err("SVD solve: U has not been computed."), - (_, None) => Err("SVD solve: V^t has not been computed."), + _ => None, } } } diff --git a/tests/linalg/svd.rs b/tests/linalg/svd.rs index 65a92ddb..251156b5 100644 --- a/tests/linalg/svd.rs +++ b/tests/linalg/svd.rs @@ -443,12 +443,22 @@ fn svd_sorted() { } #[test] -fn svd_polar_decomposition() { +fn dynamic_square_matrix_polar_decomposition() { - let m = DMatrix::::new_random(4, 4); + let m = DMatrix::::new_random(10, 10); let svd = m.clone().svd(true, true); let (p,u) = svd.to_polar().unwrap(); assert_relative_eq!(m, p*u, epsilon = 1.0e-5); } + +#[test] +fn dynamic_rectangular_matrix_polar_decomposition() { + + let m = DMatrix::::new_random(7, 5); + let svd = m.clone().svd(true, true); + let (p,u) = svd.to_polar().unwrap(); + + assert_relative_eq!(m, p*u, epsilon = 1.0e-5); +}