diff --git a/src/linalg/decompositions.rs b/src/linalg/decompositions.rs index 47409224..3cc1a0f5 100644 --- a/src/linalg/decompositions.rs +++ b/src/linalg/decompositions.rs @@ -1,11 +1,36 @@ use std::num::{Zero, Float}; -use na::DVec; use na::DMat; use traits::operations::Transpose; -use traits::structure::ColSlice; +use traits::structure::{ColSlice, Eye, Indexable}; use traits::geometry::Norm; use std::cmp::min; +/// Get the householder matrix corresponding to a reflexion to the hyperplane +/// defined by `vec̀ . It can be a reflexion contained in a subspace. +/// +/// # Arguments +/// * `dim` - the dimension of the space the resulting matrix operates in +/// * `start` - the starting dimension of the subspace of the reflexion +/// * `vec` - the vector defining the reflection. +fn householder_matrix, + V: Indexable> + (dim: uint, start: uint, vec: V) -> Mat { + let mut qk : Mat = Eye::new_identity(dim); + let stop = start + vec.shape(); + assert!(stop <= dim); + for j in range(start, stop) { + for i in range(start, stop) { + unsafe { + let vv = vec.unsafe_at(i) * vec.unsafe_at(j); + let qkij = qk.unsafe_at((i, j)); + qk.unsafe_set((i, j), qkij - vv - vv); + } + } + } + qk +} + /// QR decomposition using Householder reflections /// # Arguments /// * `m` - matrix to decompose @@ -13,42 +38,26 @@ pub fn decomp_qr(m: &DMat) -> (DMat, DMat) { let rows = m.nrows(); let cols = m.ncols(); assert!(rows >= cols); - let mut q : DMat = DMat::new_identity(rows); + let mut q : DMat = Eye::new_identity(rows); let mut r = m.clone(); - let subtract_reflection = |vec: DVec| -> DMat { - // FIXME: we don't handle the complex case here - let mut qk : DMat = DMat::new_identity(rows); - let start = rows - vec.at.len(); - for j in range(start, rows) { - for i in range(start, rows) { - unsafe { - let vv = vec.at_fast(i - start) * vec.at_fast(j - start); - let qkij = qk.at_fast(i, j); - qk.set_fast(i, j, qkij - vv - vv); - } - } - } - qk - }; - let iterations = min(rows - 1, cols); for ite in range(0u, iterations) { let mut v = r.col_slice(ite, ite, rows); let alpha = - if unsafe { v.at_fast(ite) } >= Zero::zero() { + if unsafe { v.unsafe_at(ite) } >= Zero::zero() { -Norm::norm(&v) } else { Norm::norm(&v) }; unsafe { - let x = v.at_fast(0); - v.set_fast(0, x - alpha); + let x = v.unsafe_at(0); + v.unsafe_set(0, x - alpha); } let _ = v.normalize(); - let qk = subtract_reflection(v); + let qk: DMat = householder_matrix(rows, 0, v); r = qk * r; q = q * Transpose::transpose_cpy(&qk); } diff --git a/src/na.rs b/src/na.rs index 415eb8f2..6832fb18 100644 --- a/src/na.rs +++ b/src/na.rs @@ -40,7 +40,8 @@ pub use traits::{ UniformSphereSample, AnyVec, VecExt, - ColSlice, RowSlice + ColSlice, RowSlice, + Eye }; pub use structs::{ diff --git a/src/structs/dmat.rs b/src/structs/dmat.rs index 8fac029c..40a3dea3 100644 --- a/src/structs/dmat.rs +++ b/src/structs/dmat.rs @@ -9,7 +9,7 @@ use traits::operations::ApproxEq; use std::mem; use structs::dvec::{DVec, DVecMulRhs}; use traits::operations::{Inv, Transpose, Mean, Cov}; -use traits::structure::{Cast, ColSlice, RowSlice}; +use traits::structure::{Cast, ColSlice, RowSlice, Eye, Indexable}; use std::fmt::{Show, Formatter, Result}; #[doc(hidden)] @@ -181,19 +181,19 @@ impl DMat { // FIXME: add a function to modify the dimension (to avoid useless allocations)? -impl DMat { +impl Eye for DMat { /// Builds an identity matrix. /// /// # Arguments /// * `dim` - The dimension of the matrix. A `dim`-dimensional matrix contains `dim * dim` /// components. #[inline] - pub fn new_identity(dim: uint) -> DMat { + fn new_identity(dim: uint) -> DMat { let mut res = DMat::new_zeros(dim, dim); for i in range(0u, dim) { let _1: N = One::one(); - res.set(i, i, _1); + res.set((i, i), _1); } res @@ -206,13 +206,16 @@ impl DMat { i + j * self.nrows } +} + +impl Indexable<(uint, uint), N> for DMat { /// Changes the value of a component of the matrix. /// /// # Arguments - /// * `row` - 0-based index of the line to be changed - /// * `col` - 0-based index of the column to be changed + /// * `rowcol` - 0-based tuple (row, col) to be changed #[inline] - pub fn set(&mut self, row: uint, col: uint, val: N) { + fn set(&mut self, rowcol: (uint, uint), val: N) { + let (row, col) = rowcol; assert!(row < self.nrows); assert!(col < self.ncols); @@ -222,7 +225,8 @@ impl DMat { /// Just like `set` without bounds checking. #[inline] - pub unsafe fn set_fast(&mut self, row: uint, col: uint, val: N) { + unsafe fn unsafe_set(&mut self, rowcol: (uint, uint), val: N) { + let (row, col) = rowcol; let offset = self.offset(row, col); *self.mij.as_mut_slice().unsafe_mut_ref(offset) = val } @@ -230,20 +234,38 @@ impl DMat { /// Reads the value of a component of the matrix. /// /// # Arguments - /// * `row` - 0-based index of the line to be read - /// * `col` - 0-based index of the column to be read + /// * `rowcol` - 0-based tuple (row, col) to be read #[inline] - pub fn at(&self, row: uint, col: uint) -> N { + fn at(&self, rowcol: (uint, uint)) -> N { + let (row, col) = rowcol; assert!(row < self.nrows); assert!(col < self.ncols); - unsafe { self.at_fast(row, col) } + unsafe { self.unsafe_at((row, col)) } } /// Just like `at` without bounds checking. #[inline] - pub unsafe fn at_fast(&self, row: uint, col: uint) -> N { + unsafe fn unsafe_at(&self, rowcol: (uint, uint)) -> N { + let (row, col) = rowcol; (*self.mij.as_slice().unsafe_ref(self.offset(row, col))).clone() } + + #[inline] + fn swap(&mut self, rowcol1: (uint, uint), rowcol2: (uint, uint)) { + let (row1, col1) = rowcol1; + let (row2, col2) = rowcol2; + let offset1 = self.offset(row1, col1); + let offset2 = self.offset(row2, col2); + let count = self.mij.len(); + assert!(offset1 < count); + assert!(offset1 < count); + self.mij.as_mut_slice().swap(offset1, offset2); + } + + fn shape(&self) -> (uint, uint) { + (self.nrows, self.ncols) + } + } impl + Add + Zero> DMatMulRhs> for DMat { @@ -258,10 +280,11 @@ impl + Add + Zero> DMatMulRhs> for DMat unsafe { for k in range(0u, left.ncols) { - acc = acc + left.at_fast(i, k) * right.at_fast(k, j); + acc = acc + + left.unsafe_at((i, k)) * right.unsafe_at((k, j)); } - res.set_fast(i, j, acc); + res.unsafe_set((i, j), acc); } } } @@ -282,7 +305,7 @@ DMatMulRhs> for DVec { for j in range(0u, left.ncols) { unsafe { - acc = acc + left.at_fast(i, j) * right.at_fast(j); + acc = acc + left.unsafe_at((i, j)) * right.unsafe_at(j); } } @@ -306,7 +329,7 @@ DVecMulRhs> for DMat { for j in range(0u, right.nrows) { unsafe { - acc = acc + left.at_fast(j) * right.at_fast(j, i); + acc = acc + left.unsafe_at(j) * right.unsafe_at((j, i)); } } @@ -335,7 +358,7 @@ Inv for DMat { assert!(self.nrows == self.ncols); let dim = self.nrows; - let mut res: DMat = DMat::new_identity(dim); + let mut res: DMat = Eye::new_identity(dim); let _0T: N = Zero::zero(); // inversion using Gauss-Jordan elimination @@ -347,7 +370,7 @@ Inv for DMat { let mut n0 = k; // index of a non-zero entry while n0 != dim { - if unsafe { self.at_fast(n0, k) } != _0T { + if unsafe { self.unsafe_at((n0, k)) } != _0T { break; } @@ -370,30 +393,30 @@ Inv for DMat { } unsafe { - let pivot = self.at_fast(k, k); + let pivot = self.unsafe_at((k, k)); for j in range(k, dim) { - let selfval = self.at_fast(k, j) / pivot; - self.set_fast(k, j, selfval); + let selfval = self.unsafe_at((k, j)) / pivot; + self.unsafe_set((k, j), selfval); } for j in range(0u, dim) { - let resval = res.at_fast(k, j) / pivot; - res.set_fast(k, j, resval); + let resval = res.unsafe_at((k, j)) / pivot; + res.unsafe_set((k, j), resval); } for l in range(0u, dim) { if l != k { - let normalizer = self.at_fast(l, k); + let normalizer = self.unsafe_at((l, k)); for j in range(k, dim) { - let selfval = self.at_fast(l, j) - self.at_fast(k, j) * normalizer; - self.set_fast(l, j, selfval); + let selfval = self.unsafe_at((l, j)) - self.unsafe_at((k, j)) * normalizer; + self.unsafe_set((l, j), selfval); } for j in range(0u, dim) { - let resval = res.at_fast(l, j) - res.at_fast(k, j) * normalizer; - res.set_fast(l, j, resval); + let resval = res.unsafe_at((l, j)) - res.unsafe_at((k, j)) * normalizer; + res.unsafe_set((l, j), resval); } } } @@ -422,7 +445,7 @@ impl Transpose for DMat { for i in range(0u, m.nrows) { for j in range(0u, m.ncols) { unsafe { - res.set_fast(j, i, m.at_fast(i, j)) + res.unsafe_set((j, i), m.unsafe_at((i, j))) } } } @@ -460,8 +483,8 @@ impl + Clone> Mean> for DMat { for i in range(0u, m.nrows) { for j in range(0u, m.ncols) { unsafe { - let acc = res.at_fast(j) + m.at_fast(i, j) * normalizer; - res.set_fast(j, acc); + let acc = res.unsafe_at(j) + m.unsafe_at((i, j)) * normalizer; + res.unsafe_set(j, acc); } } } @@ -482,7 +505,7 @@ impl + DMatDivRhs> + ToStr > Cov> for i in range(0u, m.nrows) { for j in range(0u, m.ncols) { unsafe { - centered.set_fast(i, j, m.at_fast(i, j) - mean.at_fast(j)); + centered.unsafe_set((i, j), m.unsafe_at((i, j)) - mean.unsafe_at(j)); } } } @@ -520,7 +543,7 @@ impl RowSlice> for DMat { let mut slice_idx = 0u; for col_id in range(col_start, col_end) { unsafe { - slice.set_fast(slice_idx, self.at_fast(row_id, col_id)); + slice.unsafe_set(slice_idx, self.unsafe_at((row_id, col_id))); } slice_idx += 1; } @@ -553,7 +576,7 @@ impl Show for DMat { fn fmt(&self, form:&mut Formatter) -> Result { for i in range(0u, self.nrows()) { for j in range(0u, self.ncols()) { - let _ = write!(form.buf, "{} ", self.at(i, j)); + let _ = write!(form.buf, "{} ", self.at((i, j))); } let _ = write!(form.buf, "\n"); } diff --git a/src/structs/dvec.rs b/src/structs/dvec.rs index b45a1a47..4da43229 100644 --- a/src/structs/dvec.rs +++ b/src/structs/dvec.rs @@ -9,7 +9,7 @@ use std::slice::{Items, MutItems}; use traits::operations::ApproxEq; use std::iter::FromIterator; use traits::geometry::{Dot, Norm}; -use traits::structure::{Iterable, IterableMut}; +use traits::structure::{Iterable, IterableMut, Indexable}; #[doc(hidden)] mod metal; @@ -48,11 +48,42 @@ impl DVec { } } -impl DVec { - /// Indexing without bounds checking. - pub unsafe fn at_fast(&self, i: uint) -> N { +impl Indexable for DVec { + + fn at(&self, i: uint) -> N { + assert!(i < self.at.len()); + unsafe { + self.unsafe_at(i) + } + } + + fn set(&mut self, i: uint, val: N) { + assert!(i < self.at.len()); + unsafe { + self.unsafe_set(i, val); + } + } + + fn swap(&mut self, i: uint, j: uint) { + assert!(i < self.at.len()); + assert!(j < self.at.len()); + self.at.as_mut_slice().swap(i, j); + } + + fn shape(&self) -> uint { + self.at.len() + } + + #[inline] + unsafe fn unsafe_at(&self, i: uint) -> N { (*self.at.as_slice().unsafe_ref(i)).clone() } + + #[inline] + unsafe fn unsafe_set(&mut self, i: uint, val: N) { + *self.at.as_mut_slice().unsafe_mut_ref(i) = val + } + } impl DVec { @@ -86,11 +117,6 @@ impl DVec { } } - #[inline] - pub unsafe fn set_fast(&mut self, i: uint, val: N) { - *self.at.as_mut_slice().unsafe_mut_ref(i) = val - } - /// Gets a reference to of this vector data. #[inline] pub fn as_vec<'r>(&'r self) -> &'r [N] { @@ -261,7 +287,7 @@ impl Dot for DVec { let mut res: N = Zero::zero(); for i in range(0u, a.at.len()) { - res = res + unsafe { a.at_fast(i) * b.at_fast(i) }; + res = res + unsafe { a.unsafe_at(i) * b.unsafe_at(i) }; } res @@ -272,7 +298,7 @@ impl Dot for DVec { let mut res: N = Zero::zero(); for i in range(0u, a.at.len()) { - res = res + unsafe { (a.at_fast(i) - b.at_fast(i)) * c.at_fast(i) }; + res = res + unsafe { (a.unsafe_at(i) - b.unsafe_at(i)) * c.unsafe_at(i) }; } res diff --git a/src/structs/mat.rs b/src/structs/mat.rs index c17831cf..0d180fcf 100644 --- a/src/structs/mat.rs +++ b/src/structs/mat.rs @@ -9,7 +9,8 @@ use std::slice::{Items, MutItems}; use structs::vec::{Vec1, Vec2, Vec3, Vec4, Vec5, Vec6, Vec1MulRhs, Vec4MulRhs, Vec5MulRhs, Vec6MulRhs}; -use traits::structure::{Cast, Row, Col, Iterable, IterableMut, Dim, Indexable}; +use traits::structure::{Cast, Row, Col, Iterable, IterableMut, Dim, Indexable, + Eye}; use traits::operations::{Absolute, Transpose, Inv, Outer}; use traits::geometry::{ToHomogeneous, FromHomogeneous}; @@ -34,6 +35,8 @@ pub struct Mat1 { pub m11: N } +eye_impl!(Mat1, 1, m11) + double_dispatch_binop_decl_trait!(Mat1, Mat1MulRhs) double_dispatch_binop_decl_trait!(Mat1, Mat1DivRhs) double_dispatch_binop_decl_trait!(Mat1, Mat1AddRhs) @@ -127,6 +130,8 @@ pub struct Mat2 { pub m12: N, pub m22: N } +eye_impl!(Mat2, 2, m11, m22) + double_dispatch_binop_decl_trait!(Mat2, Mat2MulRhs) double_dispatch_binop_decl_trait!(Mat2, Mat2DivRhs) double_dispatch_binop_decl_trait!(Mat2, Mat2AddRhs) @@ -225,6 +230,8 @@ pub struct Mat3 { pub m13: N, pub m23: N, pub m33: N } +eye_impl!(Mat3, 3, m11, m22, m33) + double_dispatch_binop_decl_trait!(Mat3, Mat3MulRhs) double_dispatch_binop_decl_trait!(Mat3, Mat3DivRhs) double_dispatch_binop_decl_trait!(Mat3, Mat3AddRhs) @@ -337,6 +344,8 @@ pub struct Mat4 { pub m14: N, pub m24: N, pub m34: N, pub m44: N } +eye_impl!(Mat4, 4, m11, m22, m33, m44) + double_dispatch_binop_decl_trait!(Mat4, Mat4MulRhs) double_dispatch_binop_decl_trait!(Mat4, Mat4DivRhs) double_dispatch_binop_decl_trait!(Mat4, Mat4AddRhs) @@ -501,6 +510,8 @@ pub struct Mat5 { pub m15: N, pub m25: N, pub m35: N, pub m45: N, pub m55: N } +eye_impl!(Mat5, 5, m11, m22, m33, m44, m55) + double_dispatch_binop_decl_trait!(Mat5, Mat5MulRhs) double_dispatch_binop_decl_trait!(Mat5, Mat5DivRhs) double_dispatch_binop_decl_trait!(Mat5, Mat5AddRhs) @@ -681,6 +692,8 @@ pub struct Mat6 { pub m16: N, pub m26: N, pub m36: N, pub m46: N, pub m56: N, pub m66: N } +eye_impl!(Mat6, 6, m11, m22, m33, m44, m55, m66) + double_dispatch_binop_decl_trait!(Mat6, Mat6MulRhs) double_dispatch_binop_decl_trait!(Mat6, Mat6DivRhs) double_dispatch_binop_decl_trait!(Mat6, Mat6AddRhs) diff --git a/src/structs/mat_macros.rs b/src/structs/mat_macros.rs index 2c13926d..6d08435b 100644 --- a/src/structs/mat_macros.rs +++ b/src/structs/mat_macros.rs @@ -98,6 +98,20 @@ macro_rules! scalar_add_impl( ) ) + +macro_rules! eye_impl( + ($t: ident, $ndim: expr, $($comp_diagN: ident),+) => ( + impl Eye for $t { + fn new_identity(dim: uint) -> $t { + assert!(dim == $ndim); + let mut eye: $t = Zero::zero(); + $(eye.$comp_diagN = One::one();)+ + eye + } + } + ) +) + macro_rules! scalar_sub_impl( ($t: ident, $n: ident, $trhs: ident, $comp0: ident $(,$compN: ident)*) => ( impl $trhs<$n, $t<$n>> for $n { @@ -193,6 +207,11 @@ macro_rules! indexable_impl( } } + #[inline] + fn shape(&self) -> (uint, uint) { + ($dim, $dim) + } + #[inline] unsafe fn unsafe_at(&self, (i, j): (uint, uint)) -> N { (*cast::transmute::<&$t, &[N, ..$dim * $dim]>(self).unsafe_ref(i + j * $dim)).clone() diff --git a/src/structs/spec/vec0.rs b/src/structs/spec/vec0.rs index 4843e3e1..ce493dda 100644 --- a/src/structs/spec/vec0.rs +++ b/src/structs/spec/vec0.rs @@ -25,6 +25,11 @@ impl Indexable for vec::Vec0 { fn set(&mut self, _: uint, _: N) { } + #[inline] + fn shape(&self) -> uint { + 0 + } + #[inline] fn swap(&mut self, _: uint, _: uint) { } diff --git a/src/structs/vec_macros.rs b/src/structs/vec_macros.rs index 0eeafb93..90e5ca9f 100644 --- a/src/structs/vec_macros.rs +++ b/src/structs/vec_macros.rs @@ -165,6 +165,11 @@ macro_rules! indexable_impl( } } + #[inline] + fn shape(&self) -> uint { + $dim + } + #[inline] fn swap(&mut self, i1: uint, i2: uint) { unsafe { diff --git a/src/traits/mod.rs b/src/traits/mod.rs index 9c774d80..1c471b45 100644 --- a/src/traits/mod.rs +++ b/src/traits/mod.rs @@ -6,7 +6,7 @@ pub use self::geometry::{AbsoluteRotate, Cross, CrossMatrix, Dot, FromHomogeneou pub use self::structure::{FloatVec, FloatVecExt, Basis, Cast, Col, Dim, Indexable, Iterable, IterableMut, Mat, Row, AnyVec, VecExt, - ColSlice, RowSlice}; + ColSlice, RowSlice, Eye}; pub use self::operations::{Absolute, ApproxEq, Cov, Inv, LMul, Mean, Outer, PartialOrd, RMul, ScalarAdd, ScalarSub, Transpose}; diff --git a/src/traits/structure.rs b/src/traits/structure.rs index 953f13b9..0e94e6ad 100644 --- a/src/traits/structure.rs +++ b/src/traits/structure.rs @@ -19,6 +19,12 @@ pub trait Mat : Row + Col + RMul + LMul { } impl + Col + RMul + LMul, R, C> Mat for M { } +/// Trait for constructing the identity matrix +pub trait Eye { + /// Return the identity matrix of specified dimension + fn new_identity(dim: uint) -> Self; +} + // XXX: we keep ScalarAdd and ScalarSub here to avoid trait impl conflict (overriding) between the // different Add/Sub traits. This is _so_ unfortunate… @@ -126,6 +132,9 @@ pub trait Indexable { /// Swaps the `i`-th element of `self` with its `j`-th element. fn swap(&mut self, i: Index, j: Index); + /// Returns the shape of the iterable range + fn shape(&self) -> Index; + /// Reads the `i`-th element of `self`. /// /// `i` is not checked.