From e6ee11617a027340b0de223a4bba43e88212a80f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Crozet?= Date: Sun, 19 Mar 2017 22:33:01 +0100 Subject: [PATCH] Add a method to compute the trace of a matrix. Fix #231. --- CHANGELOG.md | 6 ++++++ src/core/matrix.rs | 19 ++++++++++++++++++- tests/matrix.rs | 14 ++++++++++++++ 3 files changed, 38 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b3b11a1b..74046bda 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,12 @@ documented here. This project adheres to [Semantic Versioning](http://semver.org/). +## [0.12.0] - WIP + +### Added + * `.trace()` that computes the trace of a matrix (i.e., the sum of its + diagonal elements.) + ## [0.11.0] The [website](http://nalgebra.org) has been fully rewritten and gives a good overview of all the added/modified features. diff --git a/src/core/matrix.rs b/src/core/matrix.rs index 004888e8..d87ddb3f 100644 --- a/src/core/matrix.rs +++ b/src/core/matrix.rs @@ -403,7 +403,7 @@ impl SquareMatrix /// Creates a square matrix with its diagonal set to `diag` and all other entries set to 0. #[inline] pub fn diagonal(&self) -> OwnedColumnVector { - assert!(self.is_square(), "Unable to transpose a non-square matrix in-place."); + assert!(self.is_square(), "Unable to get the diagonal of a non-square."); let dim = self.data.shape().0; let mut res = unsafe { OwnedColumnVector::::new_uninitialized_generic(dim, U1) }; @@ -414,6 +414,23 @@ impl SquareMatrix res } + + /// Computes a trace of a square matrix, i.e., the sum of its diagonal elements. + #[inline] + pub fn trace(&self) -> N + where N: Ring { + assert!(self.is_square(), "Cannot compute the trace of non-square matrix."); + + let dim = self.data.shape().0; + let mut res = N::zero(); + + for i in 0 .. dim.value() { + res += unsafe { *self.get_unchecked(i, i) }; + } + + res + + } } impl ColumnVector diff --git a/tests/matrix.rs b/tests/matrix.rs index 21e8c262..dcef35d7 100644 --- a/tests/matrix.rs +++ b/tests/matrix.rs @@ -308,6 +308,20 @@ fn simple_scalar_conversion() { assert_eq!(expected, a_u32); } +#[test] +#[should_panic] +fn trace_panic() { + let m = DMatrix::::new_random(2, 3); + let _ = m.trace(); +} + +#[test] +fn trace() { + let m = Matrix2::new(1.0, 20.0, + 30.0, 4.0); + assert_eq!(m.trace(), 5.0); +} + #[test] fn simple_transpose() { let a = Matrix2x3::new(1.0, 2.0, 3.0,