From 5611307b4de32b04047c6f3a3da4c3ab643a0372 Mon Sep 17 00:00:00 2001 From: Vincent Barrielle Date: Fri, 9 May 2014 22:14:37 +0200 Subject: [PATCH] QR decomposition depends less on DMat internals --- src/lib.rs | 1 + src/linalg/decompositions.rs | 58 ++++++++++++++++++++++ src/linalg/mod.rs | 4 ++ src/na.rs | 8 +++- src/structs/dmat.rs | 93 +++++++++++++----------------------- src/structs/mod.rs | 1 - src/traits/mod.rs | 2 +- src/traits/operations.rs | 12 +++++ 8 files changed, 116 insertions(+), 63 deletions(-) create mode 100644 src/linalg/decompositions.rs create mode 100644 src/linalg/mod.rs diff --git a/src/lib.rs b/src/lib.rs index 3650a4d6..125c679a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -119,6 +119,7 @@ extern crate test; pub mod na; mod structs; mod traits; +mod linalg; // mod lower_triangular; // mod chol; diff --git a/src/linalg/decompositions.rs b/src/linalg/decompositions.rs new file mode 100644 index 00000000..bee21f37 --- /dev/null +++ b/src/linalg/decompositions.rs @@ -0,0 +1,58 @@ +use std::num::{Zero, Float}; +use na::DVec; +use na::DMat; +use traits::operations::{Transpose, ColSlice}; +use traits::geometry::Norm; +use std::cmp::min; + +/// QR decomposition using Householder reflections +/// # Arguments +/// * `m` matrix to decompose +pub fn decomp_qr(m: &DMat) -> (DMat, DMat) { + let rows = m.nrows(); + let cols = m.ncols(); + assert!(rows >= cols); + let mut q : DMat = DMat::new_identity(rows); + let mut r = m.clone(); + + let subtract_reflection = |vec: DVec| -> DMat { + // FIXME: we don't handle the complex case here + let mut qk : DMat = DMat::new_identity(rows); + let start = rows - vec.at.len(); + for j in range(start, rows) { + for i in range(start, rows) { + unsafe { + let vv = vec.at_fast(i-start)*vec.at_fast(j-start); + let qkij = qk.at_fast(i,j); + qk.set_fast(i, j, qkij - vv - vv); + } + } + } + qk + }; + + let iterations = min(rows-1, cols); + + for ite in range(0u, iterations) { + let mut v = r.col_slice(ite, ite, rows); + //let mut v = r.col_slice>(ite, rows-ite, rows); + let alpha = + if unsafe { v.at_fast(ite) } >= Zero::zero() { + -Norm::norm(&v) + } + else { + Norm::norm(&v) + }; + unsafe { + let x = v.at_fast(0); + v.set_fast(0, x - alpha); + } + let _ = v.normalize(); + let qk = subtract_reflection(v); + r = qk * r; + q = q * Transpose::transpose_cpy(&qk); + } + + (q, r) +} + diff --git a/src/linalg/mod.rs b/src/linalg/mod.rs new file mode 100644 index 00000000..6c2c6d63 --- /dev/null +++ b/src/linalg/mod.rs @@ -0,0 +1,4 @@ + +pub use self::decompositions::decomp_qr; + +mod decompositions; diff --git a/src/na.rs b/src/na.rs index c837c3b6..415eb8f2 100644 --- a/src/na.rs +++ b/src/na.rs @@ -39,12 +39,12 @@ pub use traits::{ Transpose, UniformSphereSample, AnyVec, - VecExt + VecExt, + ColSlice, RowSlice }; pub use structs::{ Identity, - decomp_qr, DMat, DVec, Iso2, Iso3, Iso4, Mat1, Mat2, Mat3, Mat4, @@ -53,6 +53,10 @@ pub use structs::{ Vec0, Vec1, Vec2, Vec3, Vec4, Vec5, Vec6 }; +pub use linalg::{ + decomp_qr +}; + /// Traits to work around the language limitations related to operator overloading. /// /// The trait names are formed by: diff --git a/src/structs/dmat.rs b/src/structs/dmat.rs index dda9cb55..ac9cac7c 100644 --- a/src/structs/dmat.rs +++ b/src/structs/dmat.rs @@ -4,14 +4,12 @@ use rand::Rand; use rand; -use std::num::{One, Zero, Float}; +use std::num::{One, Zero}; use traits::operations::ApproxEq; use std::mem; use structs::dvec::{DVec, DVecMulRhs}; -use traits::operations::{Inv, Transpose, Mean, Cov}; +use traits::operations::{Inv, Transpose, Mean, Cov, ColSlice, RowSlice}; use traits::structure::Cast; -use traits::geometry::Norm; -use std::cmp::min; use std::fmt::{Show, Formatter, Result}; #[doc(hidden)] @@ -497,61 +495,38 @@ impl + DMatDivRhs> + ToStr > Cov> } } - -/// QR decomposition using Householder reflections -/// # Arguments -/// * `m` matrix to decompose -pub fn decomp_qr(m: &DMat) -> (DMat, DMat) { - let rows = m.nrows(); - let cols = m.ncols(); - assert!(rows >= cols); - let mut q : DMat = DMat::new_identity(rows); - let mut r = m.clone(); - - let subtract_reflection = |vec: DVec| -> DMat { - // FIXME: we don't handle the complex case here - let mut qk : DMat = DMat::new_identity(rows); - let start = rows - vec.at.len(); - for j in range(start, rows) { - for i in range(start, rows) { - unsafe { - let vv = vec.at_fast(i-start)*vec.at_fast(j-start); - let qkij = qk.at_fast(i,j); - qk.set_fast(i, j, qkij - vv - vv); - } - } - } - qk - }; - - let iterations = min(rows-1, cols); - - for ite in range(0u, iterations) { - // we get the ite-th column truncated from its first ite elements, - // this is fast thanks to the matrix being column major - let start= m.offset(ite, ite); - let stop = m.offset(rows, ite); - let mut v = DVec::from_vec(rows - ite, r.mij.slice(start, stop)); - let alpha = - if unsafe { v.at_fast(ite) } >= Zero::zero() { - -Norm::norm(&v) - } - else { - Norm::norm(&v) - }; - unsafe { - let x = v.at_fast(0); - v.set_fast(0, x - alpha); - } - let _ = v.normalize(); - let qk = subtract_reflection(v); - r = qk * r; - q = q * Transpose::transpose_cpy(&qk); - } - - (q, r) +impl ColSlice> for DMat { + fn col_slice(&self, col_id :uint, row_start: uint, row_end: uint) -> DVec { + assert!(col_id < self.ncols); + assert!(row_start < row_end); + assert!(row_end <= self.nrows); + // we can init from slice thanks to the matrix being column major + let start= self.offset(row_start, col_id); + let stop = self.offset(row_end, col_id); + let slice = DVec::from_vec( + row_end - row_start, self.mij.slice(start, stop)); + slice + } } +impl RowSlice> for DMat { + fn row_slice(&self, row_id :uint, col_start: uint, col_end: uint) -> DVec { + assert!(row_id < self.nrows); + assert!(col_start < col_end); + assert!(col_end <= self.ncols); + let mut slice : DVec = unsafe { + DVec::new_uninitialized(self.nrows) + }; + let mut slice_idx = 0u; + for col_id in range(col_start, col_end) { + unsafe { + slice.set_fast(slice_idx, self.at_fast(row_id, col_id)); + } + slice_idx += 1; + } + slice + } +} impl> ApproxEq for DMat { #[inline] @@ -578,9 +553,9 @@ impl Show for DMat { fn fmt(&self, form:&mut Formatter) -> Result { for i in range(0u, self.nrows()) { for j in range(0u, self.ncols()) { - write!(form.buf, "{} ", self.at(i, j)); + let _ = write!(form.buf, "{} ", self.at(i, j)); } - write!(form.buf, "\n"); + let _ = write!(form.buf, "\n"); } write!(form.buf, "\n") } diff --git a/src/structs/mod.rs b/src/structs/mod.rs index b2516d02..f7ebcac4 100644 --- a/src/structs/mod.rs +++ b/src/structs/mod.rs @@ -1,7 +1,6 @@ //! Data structures and implementations. pub use self::dmat::DMat; -pub use self::dmat::decomp_qr; pub use self::dvec::DVec; pub use self::vec::{Vec0, Vec1, Vec2, Vec3, Vec4, Vec5, Vec6}; pub use self::mat::{Identity, Mat1, Mat2, Mat3, Mat4, Mat5, Mat6}; diff --git a/src/traits/mod.rs b/src/traits/mod.rs index 8a62fab5..c973e80a 100644 --- a/src/traits/mod.rs +++ b/src/traits/mod.rs @@ -8,7 +8,7 @@ pub use self::structure::{FloatVec, FloatVecExt, Basis, Cast, Col, Dim, Indexabl Iterable, IterableMut, Mat, Row, AnyVec, VecExt}; pub use self::operations::{Absolute, ApproxEq, Cov, Inv, LMul, Mean, Outer, PartialOrd, RMul, - ScalarAdd, ScalarSub, Transpose}; + ScalarAdd, ScalarSub, Transpose, ColSlice, RowSlice}; pub use self::operations::{PartialOrdering, PartialLess, PartialEqual, PartialGreater, NotComparable}; pub mod geometry; diff --git a/src/traits/operations.rs b/src/traits/operations.rs index 0f836fbc..34f38d9a 100644 --- a/src/traits/operations.rs +++ b/src/traits/operations.rs @@ -244,6 +244,18 @@ pub trait Mean { fn mean(&Self) -> N; } +/// Trait for objects that support column slicing +pub trait ColSlice { + /// Returns a view to a slice of a column of a matrix. + fn col_slice(&self, col_id :uint, row_start: uint, row_end: uint) -> VecLike; +} + +/// Trait for objects that support column slicing +pub trait RowSlice { + /// Returns a view to a slice of a row of a matrix. + fn row_slice(&self, row_id :uint, col_start: uint, col_end: uint) -> VecLike; +} + // /// Cholesky decomposition. // pub trait Chol { // /// Performs the cholesky decomposition on `self`. The resulting upper-triangular matrix is