From 663f8b3ccb4276bd8c7dc5f995d58625bd569d57 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Crozet?= Date: Sat, 16 Aug 2014 13:12:08 +0200 Subject: [PATCH] Add a `Diag` to build, get and set a matrix diagonal. --- src/na.rs | 25 +++++++++++++++++++------ src/structs/dmat.rs | 38 +++++++++++++++++++++++++++++++++++++- src/structs/mat.rs | 8 +++++++- src/structs/mat_macros.rs | 33 +++++++++++++++++++++++++++++++++ src/traits/mod.rs | 2 +- src/traits/structure.rs | 12 ++++++++++++ 6 files changed, 109 insertions(+), 9 deletions(-) diff --git a/src/na.rs b/src/na.rs index 88ceb464..2979fde3 100644 --- a/src/na.rs +++ b/src/na.rs @@ -6,18 +6,22 @@ pub use traits::{PartialLess, PartialEqual, PartialGreater, NotComparable}; pub use traits::{ Absolute, AbsoluteRotate, + AnyVec, ApproxEq, - FloatVec, - FloatVecExt, Basis, Cast, Col, + ColSlice, RowSlice, Cov, Cross, CrossMatrix, Det, + Diag, Dim, Dot, + Eye, + FloatVec, + FloatVecExt, FromHomogeneous, Indexable, Inv, @@ -40,10 +44,7 @@ pub use traits::{ Translate, Translation, Transpose, UniformSphereSample, - AnyVec, - VecExt, - ColSlice, RowSlice, - Eye + VecExt }; pub use structs::{ @@ -735,6 +736,9 @@ pub fn mean>(observations: &M) -> N { // // +/* + * Eye + */ /// Construct the identity matrix for a given dimension #[inline(always)] pub fn new_identity(dim: uint) -> M { Eye::new_identity(dim) } @@ -763,6 +767,15 @@ pub fn orthonormal_subspace_basis(v: &V, f: |V| -> bool) { * Col */ +/* + * Diag + */ +/// Gets the diagonal of a square matrix. +#[inline(always)] +pub fn diag, V>(m: &M) -> V { + m.diag() +} + /* * Dim */ diff --git a/src/structs/dmat.rs b/src/structs/dmat.rs index 8c74c1c4..f57caa78 100644 --- a/src/structs/dmat.rs +++ b/src/structs/dmat.rs @@ -2,6 +2,7 @@ #![allow(missing_doc)] // we hide doc to not have to document the $trhs double dispatch trait. +use std::cmp; use std::rand::Rand; use std::rand; use std::num::{One, Zero}; @@ -9,7 +10,7 @@ use traits::operations::ApproxEq; use std::mem; use structs::dvec::{DVec, DVecMulRhs}; use traits::operations::{Inv, Transpose, Mean, Cov}; -use traits::structure::{Cast, ColSlice, RowSlice, Eye, Indexable}; +use traits::structure::{Cast, ColSlice, RowSlice, Diag, Eye, Indexable}; use std::fmt::{Show, Formatter, Result}; @@ -549,6 +550,41 @@ impl RowSlice> for DMat { } } +impl Diag> for DMat { + #[inline] + fn from_diag(diag: &DVec) -> DMat { + let mut res = DMat::new_zeros(diag.len(), diag.len()); + + res.set_diag(diag); + + res + } + + #[inline] + fn set_diag(&mut self, diag: &DVec) { + let smallest_dim = cmp::min(self.nrows, self.ncols); + + assert!(diag.len() == smallest_dim); + + for i in range(0, smallest_dim) { + unsafe { self.unsafe_set((i, i), diag.unsafe_at(i)) } + } + } + + #[inline] + fn diag(&self) -> DVec { + let smallest_dim = cmp::min(self.nrows, self.ncols); + + let mut diag: DVec = DVec::new_zeros(smallest_dim); + + for i in range(0, smallest_dim) { + unsafe { diag.unsafe_set(i, self.unsafe_at((i, i))) } + } + + diag + } +} + impl> ApproxEq for DMat { #[inline] fn approx_epsilon(_: Option>) -> N { diff --git a/src/structs/mat.rs b/src/structs/mat.rs index 7eeae95b..3ea96a9c 100644 --- a/src/structs/mat.rs +++ b/src/structs/mat.rs @@ -11,7 +11,7 @@ use structs::vec::{Vec1, Vec2, Vec3, Vec4, Vec5, Vec6, use structs::dvec::{DVec1, DVec2, DVec3, DVec4, DVec5, DVec6}; use traits::structure::{Cast, Row, Col, Iterable, IterableMut, Dim, Indexable, - Eye, ColSlice, RowSlice}; + Eye, ColSlice, RowSlice, Diag}; use traits::operations::{Absolute, Transpose, Inv, Outer}; use traits::geometry::{ToHomogeneous, FromHomogeneous}; @@ -120,6 +120,7 @@ row_impl!(Mat1, Vec1, 1) col_impl!(Mat1, Vec1, 1) col_slice_impl!(Mat1, Vec1, DVec1, 1) row_slice_impl!(Mat1, Vec1, DVec1, 1) +diag_impl!(Mat1, Vec1, 1) to_homogeneous_impl!(Mat1, Mat2, 1, 2) from_homogeneous_impl!(Mat1, Mat2, 1, 2) outer_impl!(Vec1, Mat1) @@ -221,6 +222,7 @@ row_impl!(Mat2, Vec2, 2) col_impl!(Mat2, Vec2, 2) col_slice_impl!(Mat2, Vec2, DVec2, 2) row_slice_impl!(Mat2, Vec2, DVec2, 2) +diag_impl!(Mat2, Vec2, 2) to_homogeneous_impl!(Mat2, Mat3, 2, 3) from_homogeneous_impl!(Mat2, Mat3, 2, 3) outer_impl!(Vec2, Mat2) @@ -336,6 +338,7 @@ approx_eq_impl!(Mat3) // (specialized) col_impl!(Mat3, Vec3, 3) col_slice_impl!(Mat3, Vec3, DVec3, 3) row_slice_impl!(Mat3, Vec3, DVec3, 3) +diag_impl!(Mat3, Vec3, 3) to_homogeneous_impl!(Mat3, Mat4, 3, 4) from_homogeneous_impl!(Mat3, Mat4, 3, 4) outer_impl!(Vec3, Mat3) @@ -503,6 +506,7 @@ row_impl!(Mat4, Vec4, 4) col_impl!(Mat4, Vec4, 4) col_slice_impl!(Mat4, Vec4, DVec4, 4) row_slice_impl!(Mat4, Vec4, DVec4, 4) +diag_impl!(Mat4, Vec4, 4) to_homogeneous_impl!(Mat4, Mat5, 4, 5) from_homogeneous_impl!(Mat4, Mat5, 4, 5) outer_impl!(Vec4, Mat4) @@ -686,6 +690,7 @@ row_impl!(Mat5, Vec5, 5) col_impl!(Mat5, Vec5, 5) col_slice_impl!(Mat5, Vec5, DVec5, 5) row_slice_impl!(Mat5, Vec5, DVec5, 5) +diag_impl!(Mat5, Vec5, 5) to_homogeneous_impl!(Mat5, Mat6, 5, 6) from_homogeneous_impl!(Mat5, Mat6, 5, 6) outer_impl!(Vec5, Mat5) @@ -921,4 +926,5 @@ row_impl!(Mat6, Vec6, 6) col_impl!(Mat6, Vec6, 6) col_slice_impl!(Mat6, Vec6, DVec6, 6) row_slice_impl!(Mat6, Vec6, DVec6, 6) +diag_impl!(Mat6, Vec6, 6) outer_impl!(Vec6, Mat6) diff --git a/src/structs/mat_macros.rs b/src/structs/mat_macros.rs index 9274d292..6233d69e 100644 --- a/src/structs/mat_macros.rs +++ b/src/structs/mat_macros.rs @@ -307,6 +307,39 @@ macro_rules! col_impl( ) ) +macro_rules! diag_impl( + ($t: ident, $tv: ident, $dim: expr) => ( + impl Diag<$tv> for $t { + #[inline] + fn from_diag(diag: &$tv) -> $t { + let mut res: $t = Zero::zero(); + + res.set_diag(diag); + + res + } + + #[inline] + fn set_diag(&mut self, diag: &$tv) { + for i in range(0, $dim) { + unsafe { self.unsafe_set((i, i), diag.unsafe_at(i)) } + } + } + + #[inline] + fn diag(&self) -> $tv { + let mut diag: $tv = Zero::zero(); + + for i in range(0, $dim) { + unsafe { diag.unsafe_set(i, self.unsafe_at((i, i))) } + } + + diag + } + } + ) +) + macro_rules! mat_mul_mat_impl( ($t: ident, $trhs: ident, $dim: expr) => ( impl $trhs> for $t { diff --git a/src/traits/mod.rs b/src/traits/mod.rs index 299f11ab..5fe3c9fa 100644 --- a/src/traits/mod.rs +++ b/src/traits/mod.rs @@ -6,7 +6,7 @@ pub use self::geometry::{AbsoluteRotate, Cross, CrossMatrix, Dot, FromHomogeneou pub use self::structure::{FloatVec, FloatVecExt, Basis, Cast, Col, Dim, Indexable, Iterable, IterableMut, Mat, Row, AnyVec, VecExt, - ColSlice, RowSlice, Eye}; + ColSlice, RowSlice, Diag, Eye}; pub use self::operations::{Absolute, ApproxEq, Cov, Det, Inv, LMul, Mean, Outer, PartialOrd, RMul, ScalarAdd, ScalarSub, ScalarMul, ScalarDiv, Transpose}; diff --git a/src/traits/structure.rs b/src/traits/structure.rs index 2d5a56ed..39e4a65c 100644 --- a/src/traits/structure.rs +++ b/src/traits/structure.rs @@ -115,6 +115,18 @@ pub trait Dim { fn dim(unused_self: Option) -> uint; } +/// Trait to get the diagonal of square matrices. +pub trait Diag { + /// Creates a new matrix with the given diagonal. + fn from_diag(diag: &V) -> Self; + + /// Sets the diagonal of this matrix. + fn set_diag(&mut self, diag: &V); + + /// The diagonal of this matrix. + fn diag(&self) -> V; +} + // FIXME: this trait should not be on nalgebra. // however, it is needed because std::ops::Index is (strangely) to poor: it // does not have a function to set values.