diff --git a/benches/linalg/svd.rs b/benches/linalg/svd.rs index b84f60d6..c95e44de 100644 --- a/benches/linalg/svd.rs +++ b/benches/linalg/svd.rs @@ -1,4 +1,18 @@ -use na::{Matrix4, SVD}; +use na::{Matrix2, Matrix3, Matrix4, SVD}; + +fn svd_decompose_2x2_f32(bh: &mut criterion::Criterion) { + let m = Matrix2::::new_random(); + bh.bench_function("svd_decompose_2x2", move |bh| { + bh.iter(|| std::hint::black_box(SVD::new_unordered(m.clone(), true, true))) + }); +} + +fn svd_decompose_3x3_f32(bh: &mut criterion::Criterion) { + let m = Matrix3::::new_random(); + bh.bench_function("svd_decompose_3x3", move |bh| { + bh.iter(|| std::hint::black_box(SVD::new_unordered(m.clone(), true, true))) + }); +} fn svd_decompose_4x4(bh: &mut criterion::Criterion) { let m = Matrix4::::new_random(); @@ -114,6 +128,8 @@ fn pseudo_inverse_200x200(bh: &mut criterion::Criterion) { criterion_group!( svd, + svd_decompose_2x2_f32, + svd_decompose_3x3_f32, svd_decompose_4x4, svd_decompose_10x10, svd_decompose_100x100, diff --git a/src/linalg/mod.rs b/src/linalg/mod.rs index 0c272494..0209f9b1 100644 --- a/src/linalg/mod.rs +++ b/src/linalg/mod.rs @@ -24,6 +24,8 @@ mod qr; mod schur; mod solve; mod svd; +mod svd2; +mod svd3; mod symmetric_eigen; mod symmetric_tridiagonal; mod udu; diff --git a/src/linalg/qr.rs b/src/linalg/qr.rs index 5839f270..1b06e34b 100644 --- a/src/linalg/qr.rs +++ b/src/linalg/qr.rs @@ -147,6 +147,11 @@ where &self.qr } + #[must_use] + pub(crate) fn diag_internal(&self) -> &OVector> { + &self.diag + } + /// Multiplies the provided matrix by the transpose of the `Q` matrix of this decomposition. pub fn q_tr_mul(&self, rhs: &mut Matrix) // TODO: do we need a static constraint on the number of rows of rhs? diff --git a/src/linalg/svd.rs b/src/linalg/svd.rs index 00ee1a41..a2c1aecb 100644 --- a/src/linalg/svd.rs +++ b/src/linalg/svd.rs @@ -1,5 +1,6 @@ #[cfg(feature = "serde-serialize-no-std")] use serde::{Deserialize, Serialize}; +use std::any::TypeId; use approx::AbsDiffEq; use num::{One, Zero}; @@ -9,7 +10,7 @@ use crate::base::{DefaultAllocator, Matrix, Matrix2x3, OMatrix, OVector, Vector2 use crate::constraint::{SameNumberOfRows, ShapeConstraint}; use crate::dimension::{Dim, DimDiff, DimMin, DimMinimum, DimSub, U1}; use crate::storage::Storage; -use crate::RawStorage; +use crate::{Matrix2, Matrix3, RawStorage, U2, U3}; use simba::scalar::{ComplexField, RealField}; use crate::linalg::givens::GivensRotation; @@ -118,6 +119,25 @@ where ); let (nrows, ncols) = matrix.shape_generic(); let min_nrows_ncols = nrows.min(ncols); + + if TypeId::of::>() == TypeId::of::>() + && TypeId::of::() == TypeId::of::>() + { + // SAFETY: the reference transmutes are OK since we checked that the types match exactly. + let matrix: &Matrix2 = unsafe { std::mem::transmute(&matrix) }; + let result = super::svd2::svd2(matrix, compute_u, compute_v); + let typed_result: &Self = unsafe { std::mem::transmute(&result) }; + return Some(typed_result.clone()); + } else if TypeId::of::>() == TypeId::of::>() + && TypeId::of::() == TypeId::of::>() + { + // SAFETY: the reference transmutes are OK since we checked that the types match exactly. + let matrix: &Matrix3 = unsafe { std::mem::transmute(&matrix) }; + let result = super::svd3::svd3(matrix, compute_u, compute_v, eps, max_niter); + let typed_result: &Self = unsafe { std::mem::transmute(&result) }; + return Some(typed_result.clone()); + } + let dim = min_nrows_ncols.value(); let m_amax = matrix.camax(); diff --git a/src/linalg/svd2.rs b/src/linalg/svd2.rs new file mode 100644 index 00000000..b34ffb02 --- /dev/null +++ b/src/linalg/svd2.rs @@ -0,0 +1,43 @@ +use crate::{Matrix2, RealField, Vector2, SVD, U2}; + +// Implementation of the 2D SVD from https://ieeexplore.ieee.org/document/486688 +// See also https://scicomp.stackexchange.com/questions/8899/robust-algorithm-for-2-times-2-svd +pub fn svd2(m: &Matrix2, compute_u: bool, compute_v: bool) -> SVD { + let half: T = crate::convert(0.5); + let one: T = crate::convert(1.0); + + let e = (m.m11.clone() + m.m22.clone()) * half.clone(); + let f = (m.m11.clone() - m.m22.clone()) * half.clone(); + let g = (m.m21.clone() + m.m12.clone()) * half.clone(); + let h = (m.m21.clone() - m.m12.clone()) * half.clone(); + let q = (e.clone() * e.clone() + h.clone() * h.clone()).sqrt(); + let r = (f.clone() * f.clone() + g.clone() * g.clone()).sqrt(); + let sx = q.clone() + r.clone(); + let sy = q - r; + let sy_sign = if sy < T::zero() { -one.clone() } else { one }; + let singular_values = Vector2::new(sx, sy * sy_sign.clone()); + + if compute_u || compute_v { + let a1 = g.atan2(f); + let a2 = h.atan2(e); + let theta = (a2.clone() - a1.clone()) * half.clone(); + let phi = (a2 + a1) * half; + let (st, ct) = theta.sin_cos(); + let (sp, cp) = phi.sin_cos(); + + let u = Matrix2::new(cp.clone(), -sp.clone(), sp, cp); + let v_t = Matrix2::new(ct.clone(), -st.clone(), st * sy_sign.clone(), ct * sy_sign); + + SVD { + u: if compute_u { Some(u) } else { None }, + singular_values, + v_t: if compute_v { Some(v_t) } else { None }, + } + } else { + SVD { + u: None, + singular_values, + v_t: None, + } + } +} diff --git a/src/linalg/svd3.rs b/src/linalg/svd3.rs new file mode 100644 index 00000000..9d38e6e3 --- /dev/null +++ b/src/linalg/svd3.rs @@ -0,0 +1,25 @@ +use crate::{Matrix3, SVD, U3}; +use simba::scalar::RealField; + +// For the 3x3 case, on the GPU, it is much more efficient to compute the SVD +// using an eigendecomposition followed by a QR decomposition. +pub fn svd3( + m: &Matrix3, + compute_u: bool, + compute_v: bool, + eps: T, + niter: usize, +) -> Option> { + let s = m.tr_mul(&m); + let v = s.try_symmetric_eigen(eps, niter)?.eigenvectors; + let b = m * &v; + + let qr = b.qr(); + let singular_values = qr.diag_internal().map(|e| e.abs()); + + Some(SVD { + u: if compute_u { Some(qr.q()) } else { None }, + singular_values, + v_t: if compute_v { Some(v.transpose()) } else { None }, + }) +} diff --git a/tests/linalg/svd.rs b/tests/linalg/svd.rs index 030036e8..5be5c13b 100644 --- a/tests/linalg/svd.rs +++ b/tests/linalg/svd.rs @@ -97,6 +97,18 @@ mod proptest_tests { prop_assert!(v_t.is_orthogonal(1.0e-5)); } + #[test] + fn svd_static_square_3x3(m in matrix3_($scalar)) { + let svd = m.svd(true, true); + let (u, s, v_t) = (svd.u.unwrap(), svd.singular_values, svd.v_t.unwrap()); + let ds = Matrix3::from_diagonal(&s.map(|e| ComplexField::from_real(e))); + + prop_assert!(s.iter().all(|e| *e >= 0.0)); + prop_assert!(relative_eq!(m, u * ds * v_t, epsilon = 1.0e-5)); + prop_assert!(u.is_orthogonal(1.0e-5)); + prop_assert!(v_t.is_orthogonal(1.0e-5)); + } + #[test] fn svd_pseudo_inverse(m in dmatrix_($scalar)) { let svd = m.clone().svd(true, true);