diff --git a/src/linalg/udu.rs b/src/linalg/udu.rs index eada8e99..953359e8 100644 --- a/src/linalg/udu.rs +++ b/src/linalg/udu.rs @@ -2,7 +2,7 @@ use serde::{Deserialize, Serialize}; use crate::allocator::Allocator; -use crate::base::{DefaultAllocator, MatrixN}; +use crate::base::{DefaultAllocator, MatrixN, VectorN, U1}; use crate::dimension::Dim; use simba::scalar::ComplexField; @@ -11,24 +11,25 @@ use simba::scalar::ComplexField; #[derive(Clone, Debug)] pub struct UDU where - DefaultAllocator: Allocator, + DefaultAllocator: Allocator + Allocator, { /// The upper triangular matrix resulting from the factorization pub u: MatrixN, /// The diagonal matrix resulting from the factorization - pub d: MatrixN, + pub d: VectorN, } impl Copy for UDU where - DefaultAllocator: Allocator, + DefaultAllocator: Allocator + Allocator, + VectorN: Copy, MatrixN: Copy, { } impl UDU where - DefaultAllocator: Allocator, + DefaultAllocator: Allocator + Allocator, { /// Computes the UDU^T factorization /// NOTE: The provided matrix MUST be symmetric, and no verification is done in this regard. @@ -37,31 +38,31 @@ where let n = p.ncols(); let n_as_dim = D::from_usize(n); - let mut d = MatrixN::::zeros_generic(n_as_dim, n_as_dim); + let mut d = VectorN::::zeros_generic(n_as_dim, U1); let mut u = MatrixN::::zeros_generic(n_as_dim, n_as_dim); - d[(n - 1, n - 1)] = p[(n - 1, n - 1)]; + d[n - 1] = p[(n - 1, n - 1)]; u[(n - 1, n - 1)] = N::one(); for j in (0..n - 1).rev() { - u[(j, n - 1)] = p[(j, n - 1)] / d[(n - 1, n - 1)]; + u[(j, n - 1)] = p[(j, n - 1)] / d[n - 1]; } for j in (0..n - 1).rev() { for k in j + 1..n { - d[(j, j)] = d[(j, j)] + d[(k, k)] * u[(j, k)].powi(2); + d[j] = d[j] + d[k] * u[(j, k)].powi(2); } - d[(j, j)] = p[(j, j)] - d[(j, j)]; + d[j] = p[(j, j)] - d[j]; for i in (0..=j).rev() { for k in j + 1..n { - u[(i, j)] = u[(i, j)] + d[(k, k)] * u[(j, k)] * u[(i, k)]; + u[(i, j)] = u[(i, j)] + d[k] * u[(j, k)] * u[(i, k)]; } u[(i, j)] = p[(i, j)] - u[(i, j)]; - u[(i, j)] /= d[(j, j)]; + u[(i, j)] /= d[j]; } u[(j, j)] = N::one(); @@ -69,4 +70,9 @@ where Self { u, d } } + + /// Returns the diagonal elements as a matrix + pub fn d_matrix(&self) -> MatrixN { + MatrixN::from_diagonal(&self.d) + } } diff --git a/tests/linalg/udu.rs b/tests/linalg/udu.rs index 0d457db5..3304d73b 100644 --- a/tests/linalg/udu.rs +++ b/tests/linalg/udu.rs @@ -11,7 +11,7 @@ fn udu_simple() { let udu = UDU::new(m); // Rebuild - let p = udu.u * udu.d * udu.u.transpose(); + let p = udu.u * udu.d_matrix() * udu.u.transpose(); assert!(relative_eq!(m, p, epsilon = 3.0e-16)); } @@ -39,7 +39,7 @@ mod quickcheck_tests { let m = m.map(|e| e.0); let udu = UDU::new(m.clone()); - let p = &udu.u * &udu.d * &udu.u.transpose(); + let p = &udu.u * &udu.d_matrix() * &udu.u.transpose(); relative_eq!(m, p, epsilon = 1.0e-7) } @@ -48,7 +48,7 @@ mod quickcheck_tests { let m = m.map(|e| e.0); let udu = UDU::new(m.clone()); - let p = udu.u * udu.d * udu.u.transpose(); + let p = udu.u * udu.d_matrix() * udu.u.transpose(); relative_eq!(m, p, epsilon = 3.0e-16) }