Merge pull request #609 from aplund/dev

Refactor row_sum() and column_sum() to cover more cases.
This commit is contained in:
Sébastien Crozet 2020-03-02 10:06:29 +01:00
commit 1d64de3822

View File

@ -1,5 +1,5 @@
use crate::{Scalar, Dim, Matrix, VectorN, RowVectorN, DefaultAllocator, U1, VectorSliceN}; use crate::{Scalar, Dim, Matrix, VectorN, RowVectorN, DefaultAllocator, U1, VectorSliceN};
use alga::general::{Field, SupersetOf}; use alga::general::{AdditiveMonoid, Field, SupersetOf};
use crate::storage::Storage; use crate::storage::Storage;
use crate::allocator::Allocator; use crate::allocator::Allocator;
@ -54,7 +54,7 @@ impl<N: Scalar, R: Dim, C: Dim, S: Storage<N, R, C>> Matrix<N, R, C, S> {
} }
} }
impl<N: Scalar + Field + SupersetOf<f64>, R: Dim, C: Dim, S: Storage<N, R, C>> Matrix<N, R, C, S> { impl<N: Scalar + AdditiveMonoid, R: Dim, C: Dim, S: Storage<N, R, C>> Matrix<N, R, C, S> {
/* /*
* *
* Sum computation. * Sum computation.
@ -83,11 +83,15 @@ impl<N: Scalar + Field + SupersetOf<f64>, R: Dim, C: Dim, S: Storage<N, R, C>> M
/// # Example /// # Example
/// ///
/// ``` /// ```
/// # use nalgebra::{Matrix2x3, RowVector3}; /// # use nalgebra::{Matrix2x3, Matrix3x2};
/// # use nalgebra::{RowVector2, RowVector3};
/// ///
/// let m = Matrix2x3::new(1.0, 2.0, 3.0, /// let m = Matrix2x3::new(1.0, 2.0, 3.0,
/// 4.0, 5.0, 6.0); /// 4.0, 5.0, 6.0);
/// assert_eq!(m.row_sum(), RowVector3::new(5.0, 7.0, 9.0)); /// assert_eq!(m.row_sum(), RowVector3::new(5.0, 7.0, 9.0));
///
/// let mint = Matrix3x2::new(1,2,3,4,5,6);
/// assert_eq!(mint.row_sum(), RowVector2::new(9,12));
/// ``` /// ```
#[inline] #[inline]
pub fn row_sum(&self) -> RowVectorN<N, C> pub fn row_sum(&self) -> RowVectorN<N, C>
@ -100,11 +104,15 @@ impl<N: Scalar + Field + SupersetOf<f64>, R: Dim, C: Dim, S: Storage<N, R, C>> M
/// # Example /// # Example
/// ///
/// ``` /// ```
/// # use nalgebra::{Matrix2x3, Vector3}; /// # use nalgebra::{Matrix2x3, Matrix3x2};
/// # use nalgebra::{Vector2, Vector3};
/// ///
/// let m = Matrix2x3::new(1.0, 2.0, 3.0, /// let m = Matrix2x3::new(1.0, 2.0, 3.0,
/// 4.0, 5.0, 6.0); /// 4.0, 5.0, 6.0);
/// assert_eq!(m.row_sum_tr(), Vector3::new(5.0, 7.0, 9.0)); /// assert_eq!(m.row_sum_tr(), Vector3::new(5.0, 7.0, 9.0));
///
/// let mint = Matrix3x2::new(1,2,3,4,5,6);
/// assert_eq!(mint.row_sum_tr(), Vector2::new(9,12));
/// ``` /// ```
#[inline] #[inline]
pub fn row_sum_tr(&self) -> VectorN<N, C> pub fn row_sum_tr(&self) -> VectorN<N, C>
@ -117,21 +125,27 @@ impl<N: Scalar + Field + SupersetOf<f64>, R: Dim, C: Dim, S: Storage<N, R, C>> M
/// # Example /// # Example
/// ///
/// ``` /// ```
/// # use nalgebra::{Matrix2x3, Vector2}; /// # use nalgebra::{Matrix2x3, Matrix3x2};
/// # use nalgebra::{Vector2, Vector3};
/// ///
/// let m = Matrix2x3::new(1.0, 2.0, 3.0, /// let m = Matrix2x3::new(1.0, 2.0, 3.0,
/// 4.0, 5.0, 6.0); /// 4.0, 5.0, 6.0);
/// assert_eq!(m.column_sum(), Vector2::new(6.0, 15.0)); /// assert_eq!(m.column_sum(), Vector2::new(6.0, 15.0));
///
/// let mint = Matrix3x2::new(1,2,3,4,5,6);
/// assert_eq!(mint.column_sum(), Vector3::new(3,7,11));
/// ``` /// ```
#[inline] #[inline]
pub fn column_sum(&self) -> VectorN<N, R> pub fn column_sum(&self) -> VectorN<N, R>
where DefaultAllocator: Allocator<N, R> { where DefaultAllocator: Allocator<N, R> {
let nrows = self.data.shape().0; let nrows = self.data.shape().0;
self.compress_columns(VectorN::zeros_generic(nrows, U1), |out, col| { self.compress_columns(VectorN::zeros_generic(nrows, U1), |out, col| {
out.axpy(N::one(), &col, N::one()) *out += col;
}) })
} }
}
impl<N: Scalar + Field + SupersetOf<f64>, R: Dim, C: Dim, S: Storage<N, R, C>> Matrix<N, R, C, S> {
/* /*
* *
* Variance computation. * Variance computation.