Add a `Diag` to build, get and set a matrix diagonal.

This commit is contained in:
Sébastien Crozet 2014-08-16 13:12:08 +02:00
parent 40c9915870
commit 663f8b3ccb
6 changed files with 109 additions and 9 deletions

View File

@ -6,18 +6,22 @@ pub use traits::{PartialLess, PartialEqual, PartialGreater, NotComparable};
pub use traits::{
Absolute,
AbsoluteRotate,
AnyVec,
ApproxEq,
FloatVec,
FloatVecExt,
Basis,
Cast,
Col,
ColSlice, RowSlice,
Cov,
Cross,
CrossMatrix,
Det,
Diag,
Dim,
Dot,
Eye,
FloatVec,
FloatVecExt,
FromHomogeneous,
Indexable,
Inv,
@ -40,10 +44,7 @@ pub use traits::{
Translate, Translation,
Transpose,
UniformSphereSample,
AnyVec,
VecExt,
ColSlice, RowSlice,
Eye
VecExt
};
pub use structs::{
@ -735,6 +736,9 @@ pub fn mean<N, M: Mean<N>>(observations: &M) -> N {
//
//
/*
* Eye
*/
/// Construct the identity matrix for a given dimension
#[inline(always)]
pub fn new_identity<M: Eye>(dim: uint) -> M { Eye::new_identity(dim) }
@ -763,6 +767,15 @@ pub fn orthonormal_subspace_basis<V: Basis>(v: &V, f: |V| -> bool) {
* Col<C>
*/
/*
* Diag<V>
*/
/// Gets the diagonal of a square matrix.
#[inline(always)]
pub fn diag<M: Diag<V>, V>(m: &M) -> V {
m.diag()
}
/*
* Dim
*/

View File

@ -2,6 +2,7 @@
#![allow(missing_doc)] // we hide doc to not have to document the $trhs double dispatch trait.
use std::cmp;
use std::rand::Rand;
use std::rand;
use std::num::{One, Zero};
@ -9,7 +10,7 @@ use traits::operations::ApproxEq;
use std::mem;
use structs::dvec::{DVec, DVecMulRhs};
use traits::operations::{Inv, Transpose, Mean, Cov};
use traits::structure::{Cast, ColSlice, RowSlice, Eye, Indexable};
use traits::structure::{Cast, ColSlice, RowSlice, Diag, Eye, Indexable};
use std::fmt::{Show, Formatter, Result};
@ -549,6 +550,41 @@ impl<N: Clone> RowSlice<DVec<N>> for DMat<N> {
}
}
impl<N: Clone + Zero> Diag<DVec<N>> for DMat<N> {
#[inline]
fn from_diag(diag: &DVec<N>) -> DMat<N> {
let mut res = DMat::new_zeros(diag.len(), diag.len());
res.set_diag(diag);
res
}
#[inline]
fn set_diag(&mut self, diag: &DVec<N>) {
let smallest_dim = cmp::min(self.nrows, self.ncols);
assert!(diag.len() == smallest_dim);
for i in range(0, smallest_dim) {
unsafe { self.unsafe_set((i, i), diag.unsafe_at(i)) }
}
}
#[inline]
fn diag(&self) -> DVec<N> {
let smallest_dim = cmp::min(self.nrows, self.ncols);
let mut diag: DVec<N> = DVec::new_zeros(smallest_dim);
for i in range(0, smallest_dim) {
unsafe { diag.unsafe_set(i, self.unsafe_at((i, i))) }
}
diag
}
}
impl<N: ApproxEq<N>> ApproxEq<N> for DMat<N> {
#[inline]
fn approx_epsilon(_: Option<DMat<N>>) -> N {

View File

@ -11,7 +11,7 @@ use structs::vec::{Vec1, Vec2, Vec3, Vec4, Vec5, Vec6,
use structs::dvec::{DVec1, DVec2, DVec3, DVec4, DVec5, DVec6};
use traits::structure::{Cast, Row, Col, Iterable, IterableMut, Dim, Indexable,
Eye, ColSlice, RowSlice};
Eye, ColSlice, RowSlice, Diag};
use traits::operations::{Absolute, Transpose, Inv, Outer};
use traits::geometry::{ToHomogeneous, FromHomogeneous};
@ -120,6 +120,7 @@ row_impl!(Mat1, Vec1, 1)
col_impl!(Mat1, Vec1, 1)
col_slice_impl!(Mat1, Vec1, DVec1, 1)
row_slice_impl!(Mat1, Vec1, DVec1, 1)
diag_impl!(Mat1, Vec1, 1)
to_homogeneous_impl!(Mat1, Mat2, 1, 2)
from_homogeneous_impl!(Mat1, Mat2, 1, 2)
outer_impl!(Vec1, Mat1)
@ -221,6 +222,7 @@ row_impl!(Mat2, Vec2, 2)
col_impl!(Mat2, Vec2, 2)
col_slice_impl!(Mat2, Vec2, DVec2, 2)
row_slice_impl!(Mat2, Vec2, DVec2, 2)
diag_impl!(Mat2, Vec2, 2)
to_homogeneous_impl!(Mat2, Mat3, 2, 3)
from_homogeneous_impl!(Mat2, Mat3, 2, 3)
outer_impl!(Vec2, Mat2)
@ -336,6 +338,7 @@ approx_eq_impl!(Mat3)
// (specialized) col_impl!(Mat3, Vec3, 3)
col_slice_impl!(Mat3, Vec3, DVec3, 3)
row_slice_impl!(Mat3, Vec3, DVec3, 3)
diag_impl!(Mat3, Vec3, 3)
to_homogeneous_impl!(Mat3, Mat4, 3, 4)
from_homogeneous_impl!(Mat3, Mat4, 3, 4)
outer_impl!(Vec3, Mat3)
@ -503,6 +506,7 @@ row_impl!(Mat4, Vec4, 4)
col_impl!(Mat4, Vec4, 4)
col_slice_impl!(Mat4, Vec4, DVec4, 4)
row_slice_impl!(Mat4, Vec4, DVec4, 4)
diag_impl!(Mat4, Vec4, 4)
to_homogeneous_impl!(Mat4, Mat5, 4, 5)
from_homogeneous_impl!(Mat4, Mat5, 4, 5)
outer_impl!(Vec4, Mat4)
@ -686,6 +690,7 @@ row_impl!(Mat5, Vec5, 5)
col_impl!(Mat5, Vec5, 5)
col_slice_impl!(Mat5, Vec5, DVec5, 5)
row_slice_impl!(Mat5, Vec5, DVec5, 5)
diag_impl!(Mat5, Vec5, 5)
to_homogeneous_impl!(Mat5, Mat6, 5, 6)
from_homogeneous_impl!(Mat5, Mat6, 5, 6)
outer_impl!(Vec5, Mat5)
@ -921,4 +926,5 @@ row_impl!(Mat6, Vec6, 6)
col_impl!(Mat6, Vec6, 6)
col_slice_impl!(Mat6, Vec6, DVec6, 6)
row_slice_impl!(Mat6, Vec6, DVec6, 6)
diag_impl!(Mat6, Vec6, 6)
outer_impl!(Vec6, Mat6)

View File

@ -307,6 +307,39 @@ macro_rules! col_impl(
)
)
macro_rules! diag_impl(
($t: ident, $tv: ident, $dim: expr) => (
impl<N: Clone + Zero> Diag<$tv<N>> for $t<N> {
#[inline]
fn from_diag(diag: &$tv<N>) -> $t<N> {
let mut res: $t<N> = Zero::zero();
res.set_diag(diag);
res
}
#[inline]
fn set_diag(&mut self, diag: &$tv<N>) {
for i in range(0, $dim) {
unsafe { self.unsafe_set((i, i), diag.unsafe_at(i)) }
}
}
#[inline]
fn diag(&self) -> $tv<N> {
let mut diag: $tv<N> = Zero::zero();
for i in range(0, $dim) {
unsafe { diag.unsafe_set(i, self.unsafe_at((i, i))) }
}
diag
}
}
)
)
macro_rules! mat_mul_mat_impl(
($t: ident, $trhs: ident, $dim: expr) => (
impl<N: Clone + Num> $trhs<N, $t<N>> for $t<N> {

View File

@ -6,7 +6,7 @@ pub use self::geometry::{AbsoluteRotate, Cross, CrossMatrix, Dot, FromHomogeneou
pub use self::structure::{FloatVec, FloatVecExt, Basis, Cast, Col, Dim, Indexable,
Iterable, IterableMut, Mat, Row, AnyVec, VecExt,
ColSlice, RowSlice, Eye};
ColSlice, RowSlice, Diag, Eye};
pub use self::operations::{Absolute, ApproxEq, Cov, Det, Inv, LMul, Mean, Outer, PartialOrd, RMul,
ScalarAdd, ScalarSub, ScalarMul, ScalarDiv, Transpose};

View File

@ -115,6 +115,18 @@ pub trait Dim {
fn dim(unused_self: Option<Self>) -> uint;
}
/// Trait to get the diagonal of square matrices.
pub trait Diag<V> {
/// Creates a new matrix with the given diagonal.
fn from_diag(diag: &V) -> Self;
/// Sets the diagonal of this matrix.
fn set_diag(&mut self, diag: &V);
/// The diagonal of this matrix.
fn diag(&self) -> V;
}
// FIXME: this trait should not be on nalgebra.
// however, it is needed because std::ops::Index is (strangely) to poor: it
// does not have a function to set values.