Add a `Diag` to build, get and set a matrix diagonal.

This commit is contained in:
Sébastien Crozet 2014-08-16 13:12:08 +02:00
parent 40c9915870
commit 663f8b3ccb
6 changed files with 109 additions and 9 deletions

View File

@ -6,18 +6,22 @@ pub use traits::{PartialLess, PartialEqual, PartialGreater, NotComparable};
pub use traits::{ pub use traits::{
Absolute, Absolute,
AbsoluteRotate, AbsoluteRotate,
AnyVec,
ApproxEq, ApproxEq,
FloatVec,
FloatVecExt,
Basis, Basis,
Cast, Cast,
Col, Col,
ColSlice, RowSlice,
Cov, Cov,
Cross, Cross,
CrossMatrix, CrossMatrix,
Det, Det,
Diag,
Dim, Dim,
Dot, Dot,
Eye,
FloatVec,
FloatVecExt,
FromHomogeneous, FromHomogeneous,
Indexable, Indexable,
Inv, Inv,
@ -40,10 +44,7 @@ pub use traits::{
Translate, Translation, Translate, Translation,
Transpose, Transpose,
UniformSphereSample, UniformSphereSample,
AnyVec, VecExt
VecExt,
ColSlice, RowSlice,
Eye
}; };
pub use structs::{ pub use structs::{
@ -735,6 +736,9 @@ pub fn mean<N, M: Mean<N>>(observations: &M) -> N {
// //
// //
/*
* Eye
*/
/// Construct the identity matrix for a given dimension /// Construct the identity matrix for a given dimension
#[inline(always)] #[inline(always)]
pub fn new_identity<M: Eye>(dim: uint) -> M { Eye::new_identity(dim) } pub fn new_identity<M: Eye>(dim: uint) -> M { Eye::new_identity(dim) }
@ -763,6 +767,15 @@ pub fn orthonormal_subspace_basis<V: Basis>(v: &V, f: |V| -> bool) {
* Col<C> * Col<C>
*/ */
/*
* Diag<V>
*/
/// Gets the diagonal of a square matrix.
#[inline(always)]
pub fn diag<M: Diag<V>, V>(m: &M) -> V {
m.diag()
}
/* /*
* Dim * Dim
*/ */

View File

@ -2,6 +2,7 @@
#![allow(missing_doc)] // we hide doc to not have to document the $trhs double dispatch trait. #![allow(missing_doc)] // we hide doc to not have to document the $trhs double dispatch trait.
use std::cmp;
use std::rand::Rand; use std::rand::Rand;
use std::rand; use std::rand;
use std::num::{One, Zero}; use std::num::{One, Zero};
@ -9,7 +10,7 @@ use traits::operations::ApproxEq;
use std::mem; use std::mem;
use structs::dvec::{DVec, DVecMulRhs}; use structs::dvec::{DVec, DVecMulRhs};
use traits::operations::{Inv, Transpose, Mean, Cov}; use traits::operations::{Inv, Transpose, Mean, Cov};
use traits::structure::{Cast, ColSlice, RowSlice, Eye, Indexable}; use traits::structure::{Cast, ColSlice, RowSlice, Diag, Eye, Indexable};
use std::fmt::{Show, Formatter, Result}; use std::fmt::{Show, Formatter, Result};
@ -549,6 +550,41 @@ impl<N: Clone> RowSlice<DVec<N>> for DMat<N> {
} }
} }
impl<N: Clone + Zero> Diag<DVec<N>> for DMat<N> {
#[inline]
fn from_diag(diag: &DVec<N>) -> DMat<N> {
let mut res = DMat::new_zeros(diag.len(), diag.len());
res.set_diag(diag);
res
}
#[inline]
fn set_diag(&mut self, diag: &DVec<N>) {
let smallest_dim = cmp::min(self.nrows, self.ncols);
assert!(diag.len() == smallest_dim);
for i in range(0, smallest_dim) {
unsafe { self.unsafe_set((i, i), diag.unsafe_at(i)) }
}
}
#[inline]
fn diag(&self) -> DVec<N> {
let smallest_dim = cmp::min(self.nrows, self.ncols);
let mut diag: DVec<N> = DVec::new_zeros(smallest_dim);
for i in range(0, smallest_dim) {
unsafe { diag.unsafe_set(i, self.unsafe_at((i, i))) }
}
diag
}
}
impl<N: ApproxEq<N>> ApproxEq<N> for DMat<N> { impl<N: ApproxEq<N>> ApproxEq<N> for DMat<N> {
#[inline] #[inline]
fn approx_epsilon(_: Option<DMat<N>>) -> N { fn approx_epsilon(_: Option<DMat<N>>) -> N {

View File

@ -11,7 +11,7 @@ use structs::vec::{Vec1, Vec2, Vec3, Vec4, Vec5, Vec6,
use structs::dvec::{DVec1, DVec2, DVec3, DVec4, DVec5, DVec6}; use structs::dvec::{DVec1, DVec2, DVec3, DVec4, DVec5, DVec6};
use traits::structure::{Cast, Row, Col, Iterable, IterableMut, Dim, Indexable, use traits::structure::{Cast, Row, Col, Iterable, IterableMut, Dim, Indexable,
Eye, ColSlice, RowSlice}; Eye, ColSlice, RowSlice, Diag};
use traits::operations::{Absolute, Transpose, Inv, Outer}; use traits::operations::{Absolute, Transpose, Inv, Outer};
use traits::geometry::{ToHomogeneous, FromHomogeneous}; use traits::geometry::{ToHomogeneous, FromHomogeneous};
@ -120,6 +120,7 @@ row_impl!(Mat1, Vec1, 1)
col_impl!(Mat1, Vec1, 1) col_impl!(Mat1, Vec1, 1)
col_slice_impl!(Mat1, Vec1, DVec1, 1) col_slice_impl!(Mat1, Vec1, DVec1, 1)
row_slice_impl!(Mat1, Vec1, DVec1, 1) row_slice_impl!(Mat1, Vec1, DVec1, 1)
diag_impl!(Mat1, Vec1, 1)
to_homogeneous_impl!(Mat1, Mat2, 1, 2) to_homogeneous_impl!(Mat1, Mat2, 1, 2)
from_homogeneous_impl!(Mat1, Mat2, 1, 2) from_homogeneous_impl!(Mat1, Mat2, 1, 2)
outer_impl!(Vec1, Mat1) outer_impl!(Vec1, Mat1)
@ -221,6 +222,7 @@ row_impl!(Mat2, Vec2, 2)
col_impl!(Mat2, Vec2, 2) col_impl!(Mat2, Vec2, 2)
col_slice_impl!(Mat2, Vec2, DVec2, 2) col_slice_impl!(Mat2, Vec2, DVec2, 2)
row_slice_impl!(Mat2, Vec2, DVec2, 2) row_slice_impl!(Mat2, Vec2, DVec2, 2)
diag_impl!(Mat2, Vec2, 2)
to_homogeneous_impl!(Mat2, Mat3, 2, 3) to_homogeneous_impl!(Mat2, Mat3, 2, 3)
from_homogeneous_impl!(Mat2, Mat3, 2, 3) from_homogeneous_impl!(Mat2, Mat3, 2, 3)
outer_impl!(Vec2, Mat2) outer_impl!(Vec2, Mat2)
@ -336,6 +338,7 @@ approx_eq_impl!(Mat3)
// (specialized) col_impl!(Mat3, Vec3, 3) // (specialized) col_impl!(Mat3, Vec3, 3)
col_slice_impl!(Mat3, Vec3, DVec3, 3) col_slice_impl!(Mat3, Vec3, DVec3, 3)
row_slice_impl!(Mat3, Vec3, DVec3, 3) row_slice_impl!(Mat3, Vec3, DVec3, 3)
diag_impl!(Mat3, Vec3, 3)
to_homogeneous_impl!(Mat3, Mat4, 3, 4) to_homogeneous_impl!(Mat3, Mat4, 3, 4)
from_homogeneous_impl!(Mat3, Mat4, 3, 4) from_homogeneous_impl!(Mat3, Mat4, 3, 4)
outer_impl!(Vec3, Mat3) outer_impl!(Vec3, Mat3)
@ -503,6 +506,7 @@ row_impl!(Mat4, Vec4, 4)
col_impl!(Mat4, Vec4, 4) col_impl!(Mat4, Vec4, 4)
col_slice_impl!(Mat4, Vec4, DVec4, 4) col_slice_impl!(Mat4, Vec4, DVec4, 4)
row_slice_impl!(Mat4, Vec4, DVec4, 4) row_slice_impl!(Mat4, Vec4, DVec4, 4)
diag_impl!(Mat4, Vec4, 4)
to_homogeneous_impl!(Mat4, Mat5, 4, 5) to_homogeneous_impl!(Mat4, Mat5, 4, 5)
from_homogeneous_impl!(Mat4, Mat5, 4, 5) from_homogeneous_impl!(Mat4, Mat5, 4, 5)
outer_impl!(Vec4, Mat4) outer_impl!(Vec4, Mat4)
@ -686,6 +690,7 @@ row_impl!(Mat5, Vec5, 5)
col_impl!(Mat5, Vec5, 5) col_impl!(Mat5, Vec5, 5)
col_slice_impl!(Mat5, Vec5, DVec5, 5) col_slice_impl!(Mat5, Vec5, DVec5, 5)
row_slice_impl!(Mat5, Vec5, DVec5, 5) row_slice_impl!(Mat5, Vec5, DVec5, 5)
diag_impl!(Mat5, Vec5, 5)
to_homogeneous_impl!(Mat5, Mat6, 5, 6) to_homogeneous_impl!(Mat5, Mat6, 5, 6)
from_homogeneous_impl!(Mat5, Mat6, 5, 6) from_homogeneous_impl!(Mat5, Mat6, 5, 6)
outer_impl!(Vec5, Mat5) outer_impl!(Vec5, Mat5)
@ -921,4 +926,5 @@ row_impl!(Mat6, Vec6, 6)
col_impl!(Mat6, Vec6, 6) col_impl!(Mat6, Vec6, 6)
col_slice_impl!(Mat6, Vec6, DVec6, 6) col_slice_impl!(Mat6, Vec6, DVec6, 6)
row_slice_impl!(Mat6, Vec6, DVec6, 6) row_slice_impl!(Mat6, Vec6, DVec6, 6)
diag_impl!(Mat6, Vec6, 6)
outer_impl!(Vec6, Mat6) outer_impl!(Vec6, Mat6)

View File

@ -307,6 +307,39 @@ macro_rules! col_impl(
) )
) )
macro_rules! diag_impl(
($t: ident, $tv: ident, $dim: expr) => (
impl<N: Clone + Zero> Diag<$tv<N>> for $t<N> {
#[inline]
fn from_diag(diag: &$tv<N>) -> $t<N> {
let mut res: $t<N> = Zero::zero();
res.set_diag(diag);
res
}
#[inline]
fn set_diag(&mut self, diag: &$tv<N>) {
for i in range(0, $dim) {
unsafe { self.unsafe_set((i, i), diag.unsafe_at(i)) }
}
}
#[inline]
fn diag(&self) -> $tv<N> {
let mut diag: $tv<N> = Zero::zero();
for i in range(0, $dim) {
unsafe { diag.unsafe_set(i, self.unsafe_at((i, i))) }
}
diag
}
}
)
)
macro_rules! mat_mul_mat_impl( macro_rules! mat_mul_mat_impl(
($t: ident, $trhs: ident, $dim: expr) => ( ($t: ident, $trhs: ident, $dim: expr) => (
impl<N: Clone + Num> $trhs<N, $t<N>> for $t<N> { impl<N: Clone + Num> $trhs<N, $t<N>> for $t<N> {

View File

@ -6,7 +6,7 @@ pub use self::geometry::{AbsoluteRotate, Cross, CrossMatrix, Dot, FromHomogeneou
pub use self::structure::{FloatVec, FloatVecExt, Basis, Cast, Col, Dim, Indexable, pub use self::structure::{FloatVec, FloatVecExt, Basis, Cast, Col, Dim, Indexable,
Iterable, IterableMut, Mat, Row, AnyVec, VecExt, Iterable, IterableMut, Mat, Row, AnyVec, VecExt,
ColSlice, RowSlice, Eye}; ColSlice, RowSlice, Diag, Eye};
pub use self::operations::{Absolute, ApproxEq, Cov, Det, Inv, LMul, Mean, Outer, PartialOrd, RMul, pub use self::operations::{Absolute, ApproxEq, Cov, Det, Inv, LMul, Mean, Outer, PartialOrd, RMul,
ScalarAdd, ScalarSub, ScalarMul, ScalarDiv, Transpose}; ScalarAdd, ScalarSub, ScalarMul, ScalarDiv, Transpose};

View File

@ -115,6 +115,18 @@ pub trait Dim {
fn dim(unused_self: Option<Self>) -> uint; fn dim(unused_self: Option<Self>) -> uint;
} }
/// Trait to get the diagonal of square matrices.
pub trait Diag<V> {
/// Creates a new matrix with the given diagonal.
fn from_diag(diag: &V) -> Self;
/// Sets the diagonal of this matrix.
fn set_diag(&mut self, diag: &V);
/// The diagonal of this matrix.
fn diag(&self) -> V;
}
// FIXME: this trait should not be on nalgebra. // FIXME: this trait should not be on nalgebra.
// however, it is needed because std::ops::Index is (strangely) to poor: it // however, it is needed because std::ops::Index is (strangely) to poor: it
// does not have a function to set values. // does not have a function to set values.