From b197959e2bd9c593a946872f725f6ccc1437e535 Mon Sep 17 00:00:00 2001 From: Daniel Date: Fri, 7 Aug 2015 14:44:25 +0200 Subject: [PATCH 1/5] Implemented Cholesky decomposition with tests --- src/lib.rs | 3 +- src/linalg/decompositions.rs | 47 +++++++++++++++++++ src/linalg/mod.rs | 2 +- tests/mat.rs | 89 +++++++++++++++++++++++++++++++++++- 4 files changed, 138 insertions(+), 3 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 3f2d7975..1c8a738c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -150,7 +150,8 @@ pub use structs::{ pub use linalg::{ qr, - householder_matrix + householder_matrix, + cholesky }; mod structs; diff --git a/src/linalg/decompositions.rs b/src/linalg/decompositions.rs index 78de85ed..01fad087 100644 --- a/src/linalg/decompositions.rs +++ b/src/linalg/decompositions.rs @@ -115,3 +115,50 @@ pub fn eigen_qr(m: &M, eps: &N, niter: usize) -> (M, V) (eigenvectors, eigenvalues.diag()) } + +/// Cholesky decomposition G of a square symmetric positive definite matrix A, such that A = G * G^T +/// +/// # Arguments +/// * `m` - square symmetric positive definite matrix to decompose +pub fn cholesky(m: &M) -> Result + where N: BaseFloat, + VS: Indexable + Norm, + M: Indexable<(usize, usize), N> + SquareMat + Add + + Sub + ColSlice + + ApproxEq + Copy { + + let mut out = m.clone(); + + for i in 0..out.nrows() { + for j in 0..(i+1) { + + let mut sum: N = out[(i,j)]; + + for k in 0..j { + sum = sum - out[(i, k)] * out[(j, k)]; + } + + if i > j { + out[(i, j)] = sum / out[(j, j)]; + } + else if i < j { + out[(i,j)] = N::zero(); + } + else if sum > N::zero() { + out[(i,i)] = sum.sqrt(); + } + else { + return Err("Cholesky: Input matrix is not positive definite to machine precision"); + } + } + } + + for i in 0..out.nrows() { + for j in i+1..out.ncols() { + out[(i,j)] = N::zero(); + } + + } + + return Ok(out); +} \ No newline at end of file diff --git a/src/linalg/mod.rs b/src/linalg/mod.rs index 4e171ca5..80b22827 100644 --- a/src/linalg/mod.rs +++ b/src/linalg/mod.rs @@ -1,4 +1,4 @@ -pub use self::decompositions::{qr, eigen_qr, householder_matrix}; +pub use self::decompositions::{qr, eigen_qr, householder_matrix, cholesky}; mod decompositions; diff --git a/tests/mat.rs b/tests/mat.rs index 88f98e0d..62edbfd2 100644 --- a/tests/mat.rs +++ b/tests/mat.rs @@ -3,7 +3,7 @@ extern crate rand; use rand::random; use na::{Vec1, Vec3, Mat1, Mat2, Mat3, Mat4, Mat5, Mat6, Rot2, Rot3, Persp3, PerspMat3, Ortho3, - OrthoMat3, DMat, DVec, Row, Col, BaseFloat}; + OrthoMat3, DMat, DVec, Row, Col, BaseFloat, Diag}; macro_rules! test_inv_mat_impl( ($t: ty) => ( @@ -41,6 +41,30 @@ macro_rules! test_qr_impl( ); ); +macro_rules! test_cholesky_impl( + ($t: ty) => ( + for _ in (0usize .. 10000) { + + // construct symmetric positive definite matrix + let mut randmat : $t = random(); + let mut diagmat : $t = Diag::from_diag(&na::diag(&randmat)); + + diagmat = na::abs(&diagmat) + 1.0; + randmat = randmat * diagmat * na::transpose(&randmat); + + let result = na::cholesky(&randmat); + + match result { + Ok(v) => { + let recomp = v * na::transpose(&v); + assert!(na::approx_eq(&randmat, &recomp)); + }, + Err(_) => assert!(false), + } + } + ); +); + // NOTE: deactivated untile we get a better convergence rate. // macro_rules! test_eigen_qr_impl( // ($t: ty) => { @@ -600,3 +624,66 @@ fn test_ortho() { assert!(na::approx_eq(&pm.znear(), &24.0)); assert!(na::approx_eq(&pm.zfar(), &61.0)); } + +#[test] +fn test_cholesky_const() { + + let a : Mat3 = Mat3::::new(1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 1.0, 2.0, 3.0); + let g : Mat3 = Mat3::::new(1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0); + + let result = na::cholesky(&a); + + match result { + Ok(v) => { + assert!(na::approx_eq(&v, &g)); + + let recomp = v * na::transpose(&v); + assert!(na::approx_eq(&recomp, &a)); + }, + + Err(_) => assert!(false), + } +} + +#[test] +fn test_cholesky_not_spd() { + + let a : Mat3 = Mat3::::new(1.0, 2.0, 3.0, 3.0, 2.0, 1.0, 1.0, 1.0, 1.0); + + let result = na::cholesky(&a); + + match result { + Ok(_) => assert!(false), + Err(_) => assert!(true), + } +} + +#[test] +fn test_cholesky_mat1() { + test_cholesky_impl!(Mat1); +} + +#[test] +fn test_cholesky_mat2() { + test_cholesky_impl!(Mat2); +} + +#[test] +fn test_cholesky_mat3() { + test_cholesky_impl!(Mat3); +} + +#[test] +fn test_cholesky_mat4() { + test_cholesky_impl!(Mat4); +} + +#[test] +fn test_cholesky_mat5() { + test_cholesky_impl!(Mat5); +} + +#[test] +fn test_cholesky_mat6() { + test_cholesky_impl!(Mat6); +} \ No newline at end of file From dc571838bbc8d8bf3a1f846fa8c22c3ff93cc4eb Mon Sep 17 00:00:00 2001 From: Daniel Date: Fri, 7 Aug 2015 15:03:38 +0200 Subject: [PATCH 2/5] Added check for symmetricity of input matrix --- src/linalg/decompositions.rs | 6 +++++- tests/mat.rs | 13 +++++++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/src/linalg/decompositions.rs b/src/linalg/decompositions.rs index 01fad087..e56e8cc4 100644 --- a/src/linalg/decompositions.rs +++ b/src/linalg/decompositions.rs @@ -127,7 +127,11 @@ pub fn cholesky(m: &M) -> Result Sub + ColSlice + ApproxEq + Copy { - let mut out = m.clone(); + let mut out = m.clone().transpose(); + + if !ApproxEq::approx_eq(&out, &m) { + return Err("Cholesky: Input matrix is not symmetric"); + } for i in 0..out.nrows() { for j in 0..(i+1) { diff --git a/tests/mat.rs b/tests/mat.rs index 62edbfd2..28878b40 100644 --- a/tests/mat.rs +++ b/tests/mat.rs @@ -658,6 +658,19 @@ fn test_cholesky_not_spd() { } } +#[test] +fn test_cholesky_not_symmetric() { + + let a : Mat2 = Mat2::::new(1.0, 1.0, -1.0, 1.0); + + let result = na::cholesky(&a); + + match result { + Ok(_) => assert!(false), + Err(_) => assert!(true), + } +} + #[test] fn test_cholesky_mat1() { test_cholesky_impl!(Mat1); From 9bb6325846269b3acb999a3a0b20b74511fa8611 Mon Sep 17 00:00:00 2001 From: Daniel Date: Sat, 8 Aug 2015 14:52:57 +0200 Subject: [PATCH 3/5] Made tests more readable --- tests/mat.rs | 23 +++++++---------------- 1 file changed, 7 insertions(+), 16 deletions(-) diff --git a/tests/mat.rs b/tests/mat.rs index 28878b40..12fc7a0a 100644 --- a/tests/mat.rs +++ b/tests/mat.rs @@ -633,16 +633,13 @@ fn test_cholesky_const() { let result = na::cholesky(&a); - match result { - Ok(v) => { - assert!(na::approx_eq(&v, &g)); + assert!(result.is_ok()); - let recomp = v * na::transpose(&v); - assert!(na::approx_eq(&recomp, &a)); - }, + let v = result.unwrap(); + assert!(na::approx_eq(&v, &g)); - Err(_) => assert!(false), - } + let recomp = v * na::transpose(&v); + assert!(na::approx_eq(&recomp, &a)); } #[test] @@ -652,10 +649,7 @@ fn test_cholesky_not_spd() { let result = na::cholesky(&a); - match result { - Ok(_) => assert!(false), - Err(_) => assert!(true), - } + assert!(result.is_err()); } #[test] @@ -665,10 +659,7 @@ fn test_cholesky_not_symmetric() { let result = na::cholesky(&a); - match result { - Ok(_) => assert!(false), - Err(_) => assert!(true), - } + assert!(result.is_err()); } #[test] From 1716dd86dbb0ba2cf08ee90d4beabcbfdf0cfe22 Mon Sep 17 00:00:00 2001 From: Daniel Date: Sat, 8 Aug 2015 17:22:47 +0200 Subject: [PATCH 4/5] Made tests more readable - missed a function --- tests/mat.rs | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/tests/mat.rs b/tests/mat.rs index 12fc7a0a..f4c70f83 100644 --- a/tests/mat.rs +++ b/tests/mat.rs @@ -54,13 +54,11 @@ macro_rules! test_cholesky_impl( let result = na::cholesky(&randmat); - match result { - Ok(v) => { - let recomp = v * na::transpose(&v); - assert!(na::approx_eq(&randmat, &recomp)); - }, - Err(_) => assert!(false), - } + assert!(result.is_ok()); + + let v = result.unwrap(); + let recomp = v * na::transpose(&v); + assert!(na::approx_eq(&randmat, &recomp)); } ); ); From 89bbe0f4b45645c0af21c164d3b3171d46d59c39 Mon Sep 17 00:00:00 2001 From: Daniel Date: Sat, 8 Aug 2015 17:52:16 +0200 Subject: [PATCH 5/5] Removed unused code --- src/linalg/decompositions.rs | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/linalg/decompositions.rs b/src/linalg/decompositions.rs index e56e8cc4..a67ef389 100644 --- a/src/linalg/decompositions.rs +++ b/src/linalg/decompositions.rs @@ -145,9 +145,6 @@ pub fn cholesky(m: &M) -> Result if i > j { out[(i, j)] = sum / out[(j, j)]; } - else if i < j { - out[(i,j)] = N::zero(); - } else if sum > N::zero() { out[(i,i)] = sum.sqrt(); }