From 39d20306f11b1fa4dd2af9ef5ec0604126fcd556 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Crozet?= Date: Fri, 2 Feb 2018 12:26:07 +0100 Subject: [PATCH] Add symmetric quadratic form computation. --- src/core/blas.rs | 45 ++++++++++++++++++++++++++++++++++++++++++--- src/core/ops.rs | 2 +- tests/core/blas.rs | 19 +++++++++++++++++++ tests/core/mod.rs | 4 ++-- 4 files changed, 64 insertions(+), 6 deletions(-) diff --git a/src/core/blas.rs b/src/core/blas.rs index feafbd02..a3f983cc 100644 --- a/src/core/blas.rs +++ b/src/core/blas.rs @@ -3,10 +3,11 @@ use num::{Zero, One, Signed}; use matrixmultiply; use alga::general::{ClosedMul, ClosedAdd}; -use core::{Scalar, Matrix, Vector}; +use core::{DefaultAllocator, Scalar, Matrix, SquareMatrix, Vector}; use core::dimension::{Dim, U1, U2, U3, U4, Dynamic}; use core::constraint::{ShapeConstraint, SameNumberOfRows, SameNumberOfColumns, AreMultipliable, DimEq}; use core::storage::{Storage, StorageMut}; +use core::allocator::Allocator; @@ -273,7 +274,7 @@ impl Vector #[inline] pub fn gemv_symm(&mut self, alpha: N, - a: &Matrix, + a: &SquareMatrix, x: &Vector, beta: N) where N: One, @@ -449,10 +450,48 @@ impl> Matrix assert!(dim1 == dim2 && dim1 == dim3, "ger: dimensions mismatch."); for j in 0 .. dim1 { - // FIXME: avoid bound checks. let val = unsafe { *y.vget_unchecked(j) }; let subdim = Dynamic::new(dim1 - j); + // FIXME: avoid bound checks. self.generic_slice_mut((j, j), (subdim, U1)).axpy(alpha * val, &x.rows_range(j ..), beta); } } } + +impl> SquareMatrix + where N: Scalar + Zero + One + ClosedAdd + ClosedMul { + + /// Computes the quadratic form `self = alpha * lrs * mid * lhs.transpose() + beta * self`. + pub fn quadform_symm(&mut self, + scratch: &mut Vector, + alpha: N, + mid: &SquareMatrix, + rhs: &Matrix, + beta: N) + where D2: Dim, R3: Dim, C3: Dim, + S2: StorageMut, + S3: Storage, + S4: Storage, + ShapeConstraint: DimEq + DimEq + DimEq, // FIXME: why is this one necessary? + DefaultAllocator: Allocator { + + + /* + scratch.gemv(N::one(), lhs, &mid.column(0), N::zero()); + self.ger_symm(alpha, &scratch, &lhs.column(0), beta); + + for j in 1 .. mid.ncols() { + scratch.gemv(N::one(), lhs, &mid.column(j), N::zero()); + self.ger_symm(alpha, &scratch, &lhs.column(j), N::one()); + } + */ + + scratch.gemv_symm(N::one(), mid, &rhs.column(0), N::zero()); + self.column_mut(0).gemv(alpha, rhs, &scratch, beta); + + for j in 1 .. mid.ncols() { + scratch.gemv_symm(N::one(), mid, &rhs.column(j), N::zero()); + self.slice_range_mut(j .., j).gemv(alpha, &rhs.rows_range(j ..), &scratch, N::one()); + } + } +} diff --git a/src/core/ops.rs b/src/core/ops.rs index f303e4d9..a1666d8d 100644 --- a/src/core/ops.rs +++ b/src/core/ops.rs @@ -6,7 +6,7 @@ use num::{Zero, One, Signed}; use alga::general::{ClosedMul, ClosedDiv, ClosedAdd, ClosedSub, ClosedNeg}; -use core::{DefaultAllocator, Scalar, Matrix, MatrixN, MatrixMN, MatrixSum}; +use core::{DefaultAllocator, Scalar, Matrix, MatrixN, MatrixMN, MatrixSum, SquareMatrix}; use core::dimension::{Dim, DimName, DimProd, DimMul}; use core::constraint::{ShapeConstraint, SameNumberOfRows, SameNumberOfColumns, AreMultipliable, DimEq}; use core::storage::{Storage, StorageMut, ContiguousStorageMut}; diff --git a/tests/core/blas.rs b/tests/core/blas.rs index 82239af0..79e5bffc 100644 --- a/tests/core/blas.rs +++ b/tests/core/blas.rs @@ -52,4 +52,23 @@ quickcheck! { relative_eq!(a1.lower_triangle(), a2) } + + fn quadform_symm(n: usize, alpha: f64, beta: f64) -> bool { + let n = cmp::max(1, cmp::min(n, 50)); + let lhs = DMatrix::::new_random(6, n); + let mut mid = DMatrix::::new_random(n, n); + let mut res = DMatrix::new_random(6, 6); + let mut scratch = Vector6::zeros(); + + mid.fill_upper_triangle_with_lower_triangle(); + + let expected = &res * beta + &lhs * &mid * lhs.transpose() * alpha; + + res.quadform_symm(&mut scratch, alpha, &lhs, &mid, beta); + res.fill_upper_triangle_with_lower_triangle(); + + println!("{}{}", res, expected); + + relative_eq!(res.lower_triangle(), expected.lower_triangle(), epsilon = 1.0e-7) + } } diff --git a/tests/core/mod.rs b/tests/core/mod.rs index 0e56e3b4..dac91598 100644 --- a/tests/core/mod.rs +++ b/tests/core/mod.rs @@ -1,8 +1,8 @@ // mod conversion; // mod edition; // mod matrix; -mod matrix_slice; -// mod blas; +// mod matrix_slice; +mod blas; // mod serde; // #[cfg(feature = "abomonation-serialize")] // mod abomonation;