more generic QR: generalize the impl of the Indexable trait

This allows the implementation of householder reflection without relying
on knowledge of DVec. This required a new member in the Indexable trait:
the shape() function, which returns the maximum index available.
This commit is contained in:
Vincent Barrielle 2014-05-11 19:46:04 +02:00
parent 43a6c96d33
commit 6ad11edf9b
10 changed files with 182 additions and 72 deletions

View File

@ -1,11 +1,36 @@
use std::num::{Zero, Float};
use na::DVec;
use na::DMat;
use traits::operations::Transpose;
use traits::structure::ColSlice;
use traits::structure::{ColSlice, Eye, Indexable};
use traits::geometry::Norm;
use std::cmp::min;
/// Get the householder matrix corresponding to a reflexion to the hyperplane
/// defined by `vec̀ . It can be a reflexion contained in a subspace.
///
/// # Arguments
/// * `dim` - the dimension of the space the resulting matrix operates in
/// * `start` - the starting dimension of the subspace of the reflexion
/// * `vec` - the vector defining the reflection.
fn householder_matrix<N: Float,
Mat: Eye + Indexable<(uint, uint), N>,
V: Indexable<uint, N>>
(dim: uint, start: uint, vec: V) -> Mat {
let mut qk : Mat = Eye::new_identity(dim);
let stop = start + vec.shape();
assert!(stop <= dim);
for j in range(start, stop) {
for i in range(start, stop) {
unsafe {
let vv = vec.unsafe_at(i) * vec.unsafe_at(j);
let qkij = qk.unsafe_at((i, j));
qk.unsafe_set((i, j), qkij - vv - vv);
}
}
}
qk
}
/// QR decomposition using Householder reflections
/// # Arguments
/// * `m` - matrix to decompose
@ -13,42 +38,26 @@ pub fn decomp_qr<N: Clone + Float>(m: &DMat<N>) -> (DMat<N>, DMat<N>) {
let rows = m.nrows();
let cols = m.ncols();
assert!(rows >= cols);
let mut q : DMat<N> = DMat::new_identity(rows);
let mut q : DMat<N> = Eye::new_identity(rows);
let mut r = m.clone();
let subtract_reflection = |vec: DVec<N>| -> DMat<N> {
// FIXME: we don't handle the complex case here
let mut qk : DMat<N> = DMat::new_identity(rows);
let start = rows - vec.at.len();
for j in range(start, rows) {
for i in range(start, rows) {
unsafe {
let vv = vec.at_fast(i - start) * vec.at_fast(j - start);
let qkij = qk.at_fast(i, j);
qk.set_fast(i, j, qkij - vv - vv);
}
}
}
qk
};
let iterations = min(rows - 1, cols);
for ite in range(0u, iterations) {
let mut v = r.col_slice(ite, ite, rows);
let alpha =
if unsafe { v.at_fast(ite) } >= Zero::zero() {
if unsafe { v.unsafe_at(ite) } >= Zero::zero() {
-Norm::norm(&v)
}
else {
Norm::norm(&v)
};
unsafe {
let x = v.at_fast(0);
v.set_fast(0, x - alpha);
let x = v.unsafe_at(0);
v.unsafe_set(0, x - alpha);
}
let _ = v.normalize();
let qk = subtract_reflection(v);
let qk: DMat<N> = householder_matrix(rows, 0, v);
r = qk * r;
q = q * Transpose::transpose_cpy(&qk);
}

View File

@ -40,7 +40,8 @@ pub use traits::{
UniformSphereSample,
AnyVec,
VecExt,
ColSlice, RowSlice
ColSlice, RowSlice,
Eye
};
pub use structs::{

View File

@ -9,7 +9,7 @@ use traits::operations::ApproxEq;
use std::mem;
use structs::dvec::{DVec, DVecMulRhs};
use traits::operations::{Inv, Transpose, Mean, Cov};
use traits::structure::{Cast, ColSlice, RowSlice};
use traits::structure::{Cast, ColSlice, RowSlice, Eye, Indexable};
use std::fmt::{Show, Formatter, Result};
#[doc(hidden)]
@ -181,19 +181,19 @@ impl<N> DMat<N> {
// FIXME: add a function to modify the dimension (to avoid useless allocations)?
impl<N: One + Zero + Clone> DMat<N> {
impl<N: One + Zero + Clone> Eye for DMat<N> {
/// Builds an identity matrix.
///
/// # Arguments
/// * `dim` - The dimension of the matrix. A `dim`-dimensional matrix contains `dim * dim`
/// components.
#[inline]
pub fn new_identity(dim: uint) -> DMat<N> {
fn new_identity(dim: uint) -> DMat<N> {
let mut res = DMat::new_zeros(dim, dim);
for i in range(0u, dim) {
let _1: N = One::one();
res.set(i, i, _1);
res.set((i, i), _1);
}
res
@ -206,13 +206,16 @@ impl<N: Clone> DMat<N> {
i + j * self.nrows
}
}
impl<N: Clone> Indexable<(uint, uint), N> for DMat<N> {
/// Changes the value of a component of the matrix.
///
/// # Arguments
/// * `row` - 0-based index of the line to be changed
/// * `col` - 0-based index of the column to be changed
/// * `rowcol` - 0-based tuple (row, col) to be changed
#[inline]
pub fn set(&mut self, row: uint, col: uint, val: N) {
fn set(&mut self, rowcol: (uint, uint), val: N) {
let (row, col) = rowcol;
assert!(row < self.nrows);
assert!(col < self.ncols);
@ -222,7 +225,8 @@ impl<N: Clone> DMat<N> {
/// Just like `set` without bounds checking.
#[inline]
pub unsafe fn set_fast(&mut self, row: uint, col: uint, val: N) {
unsafe fn unsafe_set(&mut self, rowcol: (uint, uint), val: N) {
let (row, col) = rowcol;
let offset = self.offset(row, col);
*self.mij.as_mut_slice().unsafe_mut_ref(offset) = val
}
@ -230,20 +234,38 @@ impl<N: Clone> DMat<N> {
/// Reads the value of a component of the matrix.
///
/// # Arguments
/// * `row` - 0-based index of the line to be read
/// * `col` - 0-based index of the column to be read
/// * `rowcol` - 0-based tuple (row, col) to be read
#[inline]
pub fn at(&self, row: uint, col: uint) -> N {
fn at(&self, rowcol: (uint, uint)) -> N {
let (row, col) = rowcol;
assert!(row < self.nrows);
assert!(col < self.ncols);
unsafe { self.at_fast(row, col) }
unsafe { self.unsafe_at((row, col)) }
}
/// Just like `at` without bounds checking.
#[inline]
pub unsafe fn at_fast(&self, row: uint, col: uint) -> N {
unsafe fn unsafe_at(&self, rowcol: (uint, uint)) -> N {
let (row, col) = rowcol;
(*self.mij.as_slice().unsafe_ref(self.offset(row, col))).clone()
}
#[inline]
fn swap(&mut self, rowcol1: (uint, uint), rowcol2: (uint, uint)) {
let (row1, col1) = rowcol1;
let (row2, col2) = rowcol2;
let offset1 = self.offset(row1, col1);
let offset2 = self.offset(row2, col2);
let count = self.mij.len();
assert!(offset1 < count);
assert!(offset1 < count);
self.mij.as_mut_slice().swap(offset1, offset2);
}
fn shape(&self) -> (uint, uint) {
(self.nrows, self.ncols)
}
}
impl<N: Clone + Mul<N, N> + Add<N, N> + Zero> DMatMulRhs<N, DMat<N>> for DMat<N> {
@ -258,10 +280,11 @@ impl<N: Clone + Mul<N, N> + Add<N, N> + Zero> DMatMulRhs<N, DMat<N>> for DMat<N>
unsafe {
for k in range(0u, left.ncols) {
acc = acc + left.at_fast(i, k) * right.at_fast(k, j);
acc = acc
+ left.unsafe_at((i, k)) * right.unsafe_at((k, j));
}
res.set_fast(i, j, acc);
res.unsafe_set((i, j), acc);
}
}
}
@ -282,7 +305,7 @@ DMatMulRhs<N, DVec<N>> for DVec<N> {
for j in range(0u, left.ncols) {
unsafe {
acc = acc + left.at_fast(i, j) * right.at_fast(j);
acc = acc + left.unsafe_at((i, j)) * right.unsafe_at(j);
}
}
@ -306,7 +329,7 @@ DVecMulRhs<N, DVec<N>> for DMat<N> {
for j in range(0u, right.nrows) {
unsafe {
acc = acc + left.at_fast(j) * right.at_fast(j, i);
acc = acc + left.unsafe_at(j) * right.unsafe_at((j, i));
}
}
@ -335,7 +358,7 @@ Inv for DMat<N> {
assert!(self.nrows == self.ncols);
let dim = self.nrows;
let mut res: DMat<N> = DMat::new_identity(dim);
let mut res: DMat<N> = Eye::new_identity(dim);
let _0T: N = Zero::zero();
// inversion using Gauss-Jordan elimination
@ -347,7 +370,7 @@ Inv for DMat<N> {
let mut n0 = k; // index of a non-zero entry
while n0 != dim {
if unsafe { self.at_fast(n0, k) } != _0T {
if unsafe { self.unsafe_at((n0, k)) } != _0T {
break;
}
@ -370,30 +393,30 @@ Inv for DMat<N> {
}
unsafe {
let pivot = self.at_fast(k, k);
let pivot = self.unsafe_at((k, k));
for j in range(k, dim) {
let selfval = self.at_fast(k, j) / pivot;
self.set_fast(k, j, selfval);
let selfval = self.unsafe_at((k, j)) / pivot;
self.unsafe_set((k, j), selfval);
}
for j in range(0u, dim) {
let resval = res.at_fast(k, j) / pivot;
res.set_fast(k, j, resval);
let resval = res.unsafe_at((k, j)) / pivot;
res.unsafe_set((k, j), resval);
}
for l in range(0u, dim) {
if l != k {
let normalizer = self.at_fast(l, k);
let normalizer = self.unsafe_at((l, k));
for j in range(k, dim) {
let selfval = self.at_fast(l, j) - self.at_fast(k, j) * normalizer;
self.set_fast(l, j, selfval);
let selfval = self.unsafe_at((l, j)) - self.unsafe_at((k, j)) * normalizer;
self.unsafe_set((l, j), selfval);
}
for j in range(0u, dim) {
let resval = res.at_fast(l, j) - res.at_fast(k, j) * normalizer;
res.set_fast(l, j, resval);
let resval = res.unsafe_at((l, j)) - res.unsafe_at((k, j)) * normalizer;
res.unsafe_set((l, j), resval);
}
}
}
@ -422,7 +445,7 @@ impl<N: Clone> Transpose for DMat<N> {
for i in range(0u, m.nrows) {
for j in range(0u, m.ncols) {
unsafe {
res.set_fast(j, i, m.at_fast(i, j))
res.unsafe_set((j, i), m.unsafe_at((i, j)))
}
}
}
@ -460,8 +483,8 @@ impl<N: Num + Cast<f32> + Clone> Mean<DVec<N>> for DMat<N> {
for i in range(0u, m.nrows) {
for j in range(0u, m.ncols) {
unsafe {
let acc = res.at_fast(j) + m.at_fast(i, j) * normalizer;
res.set_fast(j, acc);
let acc = res.unsafe_at(j) + m.unsafe_at((i, j)) * normalizer;
res.unsafe_set(j, acc);
}
}
}
@ -482,7 +505,7 @@ impl<N: Clone + Num + Cast<f32> + DMatDivRhs<N, DMat<N>> + ToStr > Cov<DMat<N>>
for i in range(0u, m.nrows) {
for j in range(0u, m.ncols) {
unsafe {
centered.set_fast(i, j, m.at_fast(i, j) - mean.at_fast(j));
centered.unsafe_set((i, j), m.unsafe_at((i, j)) - mean.unsafe_at(j));
}
}
}
@ -520,7 +543,7 @@ impl<N: Clone> RowSlice<DVec<N>> for DMat<N> {
let mut slice_idx = 0u;
for col_id in range(col_start, col_end) {
unsafe {
slice.set_fast(slice_idx, self.at_fast(row_id, col_id));
slice.unsafe_set(slice_idx, self.unsafe_at((row_id, col_id)));
}
slice_idx += 1;
}
@ -553,7 +576,7 @@ impl<N: Show + Clone> Show for DMat<N> {
fn fmt(&self, form:&mut Formatter) -> Result {
for i in range(0u, self.nrows()) {
for j in range(0u, self.ncols()) {
let _ = write!(form.buf, "{} ", self.at(i, j));
let _ = write!(form.buf, "{} ", self.at((i, j)));
}
let _ = write!(form.buf, "\n");
}

View File

@ -9,7 +9,7 @@ use std::slice::{Items, MutItems};
use traits::operations::ApproxEq;
use std::iter::FromIterator;
use traits::geometry::{Dot, Norm};
use traits::structure::{Iterable, IterableMut};
use traits::structure::{Iterable, IterableMut, Indexable};
#[doc(hidden)]
mod metal;
@ -48,11 +48,42 @@ impl<N: Zero + Clone> DVec<N> {
}
}
impl<N: Clone> DVec<N> {
/// Indexing without bounds checking.
pub unsafe fn at_fast(&self, i: uint) -> N {
impl<N: Clone> Indexable<uint, N> for DVec<N> {
fn at(&self, i: uint) -> N {
assert!(i < self.at.len());
unsafe {
self.unsafe_at(i)
}
}
fn set(&mut self, i: uint, val: N) {
assert!(i < self.at.len());
unsafe {
self.unsafe_set(i, val);
}
}
fn swap(&mut self, i: uint, j: uint) {
assert!(i < self.at.len());
assert!(j < self.at.len());
self.at.as_mut_slice().swap(i, j);
}
fn shape(&self) -> uint {
self.at.len()
}
#[inline]
unsafe fn unsafe_at(&self, i: uint) -> N {
(*self.at.as_slice().unsafe_ref(i)).clone()
}
#[inline]
unsafe fn unsafe_set(&mut self, i: uint, val: N) {
*self.at.as_mut_slice().unsafe_mut_ref(i) = val
}
}
impl<N: One + Clone> DVec<N> {
@ -86,11 +117,6 @@ impl<N> DVec<N> {
}
}
#[inline]
pub unsafe fn set_fast(&mut self, i: uint, val: N) {
*self.at.as_mut_slice().unsafe_mut_ref(i) = val
}
/// Gets a reference to of this vector data.
#[inline]
pub fn as_vec<'r>(&'r self) -> &'r [N] {
@ -261,7 +287,7 @@ impl<N: Num + Clone> Dot<N> for DVec<N> {
let mut res: N = Zero::zero();
for i in range(0u, a.at.len()) {
res = res + unsafe { a.at_fast(i) * b.at_fast(i) };
res = res + unsafe { a.unsafe_at(i) * b.unsafe_at(i) };
}
res
@ -272,7 +298,7 @@ impl<N: Num + Clone> Dot<N> for DVec<N> {
let mut res: N = Zero::zero();
for i in range(0u, a.at.len()) {
res = res + unsafe { (a.at_fast(i) - b.at_fast(i)) * c.at_fast(i) };
res = res + unsafe { (a.unsafe_at(i) - b.unsafe_at(i)) * c.unsafe_at(i) };
}
res

View File

@ -9,7 +9,8 @@ use std::slice::{Items, MutItems};
use structs::vec::{Vec1, Vec2, Vec3, Vec4, Vec5, Vec6, Vec1MulRhs, Vec4MulRhs,
Vec5MulRhs, Vec6MulRhs};
use traits::structure::{Cast, Row, Col, Iterable, IterableMut, Dim, Indexable};
use traits::structure::{Cast, Row, Col, Iterable, IterableMut, Dim, Indexable,
Eye};
use traits::operations::{Absolute, Transpose, Inv, Outer};
use traits::geometry::{ToHomogeneous, FromHomogeneous};
@ -34,6 +35,8 @@ pub struct Mat1<N> {
pub m11: N
}
eye_impl!(Mat1, 1, m11)
double_dispatch_binop_decl_trait!(Mat1, Mat1MulRhs)
double_dispatch_binop_decl_trait!(Mat1, Mat1DivRhs)
double_dispatch_binop_decl_trait!(Mat1, Mat1AddRhs)
@ -127,6 +130,8 @@ pub struct Mat2<N> {
pub m12: N, pub m22: N
}
eye_impl!(Mat2, 2, m11, m22)
double_dispatch_binop_decl_trait!(Mat2, Mat2MulRhs)
double_dispatch_binop_decl_trait!(Mat2, Mat2DivRhs)
double_dispatch_binop_decl_trait!(Mat2, Mat2AddRhs)
@ -225,6 +230,8 @@ pub struct Mat3<N> {
pub m13: N, pub m23: N, pub m33: N
}
eye_impl!(Mat3, 3, m11, m22, m33)
double_dispatch_binop_decl_trait!(Mat3, Mat3MulRhs)
double_dispatch_binop_decl_trait!(Mat3, Mat3DivRhs)
double_dispatch_binop_decl_trait!(Mat3, Mat3AddRhs)
@ -337,6 +344,8 @@ pub struct Mat4<N> {
pub m14: N, pub m24: N, pub m34: N, pub m44: N
}
eye_impl!(Mat4, 4, m11, m22, m33, m44)
double_dispatch_binop_decl_trait!(Mat4, Mat4MulRhs)
double_dispatch_binop_decl_trait!(Mat4, Mat4DivRhs)
double_dispatch_binop_decl_trait!(Mat4, Mat4AddRhs)
@ -501,6 +510,8 @@ pub struct Mat5<N> {
pub m15: N, pub m25: N, pub m35: N, pub m45: N, pub m55: N
}
eye_impl!(Mat5, 5, m11, m22, m33, m44, m55)
double_dispatch_binop_decl_trait!(Mat5, Mat5MulRhs)
double_dispatch_binop_decl_trait!(Mat5, Mat5DivRhs)
double_dispatch_binop_decl_trait!(Mat5, Mat5AddRhs)
@ -681,6 +692,8 @@ pub struct Mat6<N> {
pub m16: N, pub m26: N, pub m36: N, pub m46: N, pub m56: N, pub m66: N
}
eye_impl!(Mat6, 6, m11, m22, m33, m44, m55, m66)
double_dispatch_binop_decl_trait!(Mat6, Mat6MulRhs)
double_dispatch_binop_decl_trait!(Mat6, Mat6DivRhs)
double_dispatch_binop_decl_trait!(Mat6, Mat6AddRhs)

View File

@ -98,6 +98,20 @@ macro_rules! scalar_add_impl(
)
)
macro_rules! eye_impl(
($t: ident, $ndim: expr, $($comp_diagN: ident),+) => (
impl<N: Zero + One> Eye for $t<N> {
fn new_identity(dim: uint) -> $t<N> {
assert!(dim == $ndim);
let mut eye: $t<N> = Zero::zero();
$(eye.$comp_diagN = One::one();)+
eye
}
}
)
)
macro_rules! scalar_sub_impl(
($t: ident, $n: ident, $trhs: ident, $comp0: ident $(,$compN: ident)*) => (
impl $trhs<$n, $t<$n>> for $n {
@ -193,6 +207,11 @@ macro_rules! indexable_impl(
}
}
#[inline]
fn shape(&self) -> (uint, uint) {
($dim, $dim)
}
#[inline]
unsafe fn unsafe_at(&self, (i, j): (uint, uint)) -> N {
(*cast::transmute::<&$t<N>, &[N, ..$dim * $dim]>(self).unsafe_ref(i + j * $dim)).clone()

View File

@ -25,6 +25,11 @@ impl<N> Indexable<uint, N> for vec::Vec0<N> {
fn set(&mut self, _: uint, _: N) {
}
#[inline]
fn shape(&self) -> uint {
0
}
#[inline]
fn swap(&mut self, _: uint, _: uint) {
}

View File

@ -165,6 +165,11 @@ macro_rules! indexable_impl(
}
}
#[inline]
fn shape(&self) -> uint {
$dim
}
#[inline]
fn swap(&mut self, i1: uint, i2: uint) {
unsafe {

View File

@ -6,7 +6,7 @@ pub use self::geometry::{AbsoluteRotate, Cross, CrossMatrix, Dot, FromHomogeneou
pub use self::structure::{FloatVec, FloatVecExt, Basis, Cast, Col, Dim, Indexable,
Iterable, IterableMut, Mat, Row, AnyVec, VecExt,
ColSlice, RowSlice};
ColSlice, RowSlice, Eye};
pub use self::operations::{Absolute, ApproxEq, Cov, Inv, LMul, Mean, Outer, PartialOrd, RMul,
ScalarAdd, ScalarSub, Transpose};

View File

@ -19,6 +19,12 @@ pub trait Mat<R, C> : Row<R> + Col<C> + RMul<R> + LMul<C> { }
impl<M: Row<R> + Col<C> + RMul<R> + LMul<C>, R, C> Mat<R, C> for M {
}
/// Trait for constructing the identity matrix
pub trait Eye {
/// Return the identity matrix of specified dimension
fn new_identity(dim: uint) -> Self;
}
// XXX: we keep ScalarAdd and ScalarSub here to avoid trait impl conflict (overriding) between the
// different Add/Sub traits. This is _so_ unfortunate…
@ -126,6 +132,9 @@ pub trait Indexable<Index, Res> {
/// Swaps the `i`-th element of `self` with its `j`-th element.
fn swap(&mut self, i: Index, j: Index);
/// Returns the shape of the iterable range
fn shape(&self) -> Index;
/// Reads the `i`-th element of `self`.
///
/// `i` is not checked.