Add a method to compute the trace of a matrix.

Fix #231.
This commit is contained in:
Sébastien Crozet 2017-03-19 22:33:01 +01:00 committed by Sébastien Crozet
parent 4b9246ec10
commit e6ee11617a
3 changed files with 38 additions and 1 deletions

View File

@ -4,6 +4,12 @@ documented here.
This project adheres to [Semantic Versioning](http://semver.org/). This project adheres to [Semantic Versioning](http://semver.org/).
## [0.12.0] - WIP
### Added
* `.trace()` that computes the trace of a matrix (i.e., the sum of its
diagonal elements.)
## [0.11.0] ## [0.11.0]
The [website](http://nalgebra.org) has been fully rewritten and gives a good The [website](http://nalgebra.org) has been fully rewritten and gives a good
overview of all the added/modified features. overview of all the added/modified features.

View File

@ -403,7 +403,7 @@ impl<N, D: Dim, S> SquareMatrix<N, D, S>
/// Creates a square matrix with its diagonal set to `diag` and all other entries set to 0. /// Creates a square matrix with its diagonal set to `diag` and all other entries set to 0.
#[inline] #[inline]
pub fn diagonal(&self) -> OwnedColumnVector<N, D, S::Alloc> { pub fn diagonal(&self) -> OwnedColumnVector<N, D, S::Alloc> {
assert!(self.is_square(), "Unable to transpose a non-square matrix in-place."); assert!(self.is_square(), "Unable to get the diagonal of a non-square.");
let dim = self.data.shape().0; let dim = self.data.shape().0;
let mut res = unsafe { OwnedColumnVector::<N, D, S::Alloc>::new_uninitialized_generic(dim, U1) }; let mut res = unsafe { OwnedColumnVector::<N, D, S::Alloc>::new_uninitialized_generic(dim, U1) };
@ -414,6 +414,23 @@ impl<N, D: Dim, S> SquareMatrix<N, D, S>
res res
} }
/// Computes a trace of a square matrix, i.e., the sum of its diagonal elements.
#[inline]
pub fn trace(&self) -> N
where N: Ring {
assert!(self.is_square(), "Cannot compute the trace of non-square matrix.");
let dim = self.data.shape().0;
let mut res = N::zero();
for i in 0 .. dim.value() {
res += unsafe { *self.get_unchecked(i, i) };
}
res
}
} }
impl<N, D, S> ColumnVector<N, D, S> impl<N, D, S> ColumnVector<N, D, S>

View File

@ -308,6 +308,20 @@ fn simple_scalar_conversion() {
assert_eq!(expected, a_u32); assert_eq!(expected, a_u32);
} }
#[test]
#[should_panic]
fn trace_panic() {
let m = DMatrix::<f32>::new_random(2, 3);
let _ = m.trace();
}
#[test]
fn trace() {
let m = Matrix2::new(1.0, 20.0,
30.0, 4.0);
assert_eq!(m.trace(), 5.0);
}
#[test] #[test]
fn simple_transpose() { fn simple_transpose() {
let a = Matrix2x3::new(1.0, 2.0, 3.0, let a = Matrix2x3::new(1.0, 2.0, 3.0,