Add symmetric quadratic form computation.

This commit is contained in:
Sébastien Crozet 2018-02-02 12:26:07 +01:00
parent 94c1ab8e7b
commit 39d20306f1
4 changed files with 64 additions and 6 deletions

View File

@ -3,10 +3,11 @@ use num::{Zero, One, Signed};
use matrixmultiply; use matrixmultiply;
use alga::general::{ClosedMul, ClosedAdd}; use alga::general::{ClosedMul, ClosedAdd};
use core::{Scalar, Matrix, Vector}; use core::{DefaultAllocator, Scalar, Matrix, SquareMatrix, Vector};
use core::dimension::{Dim, U1, U2, U3, U4, Dynamic}; use core::dimension::{Dim, U1, U2, U3, U4, Dynamic};
use core::constraint::{ShapeConstraint, SameNumberOfRows, SameNumberOfColumns, AreMultipliable, DimEq}; use core::constraint::{ShapeConstraint, SameNumberOfRows, SameNumberOfColumns, AreMultipliable, DimEq};
use core::storage::{Storage, StorageMut}; use core::storage::{Storage, StorageMut};
use core::allocator::Allocator;
@ -273,7 +274,7 @@ impl<N, D: Dim, S> Vector<N, D, S>
#[inline] #[inline]
pub fn gemv_symm<D2: Dim, D3: Dim, SB, SC>(&mut self, pub fn gemv_symm<D2: Dim, D3: Dim, SB, SC>(&mut self,
alpha: N, alpha: N,
a: &Matrix<N, D2, D2, SB>, a: &SquareMatrix<N, D2, SB>,
x: &Vector<N, D3, SC>, x: &Vector<N, D3, SC>,
beta: N) beta: N)
where N: One, where N: One,
@ -449,10 +450,48 @@ impl<N, R1: Dim, C1: Dim, S: StorageMut<N, R1, C1>> Matrix<N, R1, C1, S>
assert!(dim1 == dim2 && dim1 == dim3, "ger: dimensions mismatch."); assert!(dim1 == dim2 && dim1 == dim3, "ger: dimensions mismatch.");
for j in 0 .. dim1 { for j in 0 .. dim1 {
// FIXME: avoid bound checks.
let val = unsafe { *y.vget_unchecked(j) }; let val = unsafe { *y.vget_unchecked(j) };
let subdim = Dynamic::new(dim1 - j); let subdim = Dynamic::new(dim1 - j);
// FIXME: avoid bound checks.
self.generic_slice_mut((j, j), (subdim, U1)).axpy(alpha * val, &x.rows_range(j ..), beta); self.generic_slice_mut((j, j), (subdim, U1)).axpy(alpha * val, &x.rows_range(j ..), beta);
} }
} }
} }
impl<N, D1: Dim, S: StorageMut<N, D1, D1>> SquareMatrix<N, D1, S>
where N: Scalar + Zero + One + ClosedAdd + ClosedMul {
/// Computes the quadratic form `self = alpha * lrs * mid * lhs.transpose() + beta * self`.
pub fn quadform_symm<D2, S2, R3, C3, S3, S4>(&mut self,
scratch: &mut Vector<N, D1, S2>,
alpha: N,
mid: &SquareMatrix<N, D2, S3>,
rhs: &Matrix<N, R3, C3, S4>,
beta: N)
where D2: Dim, R3: Dim, C3: Dim,
S2: StorageMut<N, D1>,
S3: Storage<N, D2, D2>,
S4: Storage<N, R3, C3>,
ShapeConstraint: DimEq<D1, D2> + DimEq<D2, R3> + DimEq<D1, C3>, // FIXME: why is this one necessary?
DefaultAllocator: Allocator<N, D1> {
/*
scratch.gemv(N::one(), lhs, &mid.column(0), N::zero());
self.ger_symm(alpha, &scratch, &lhs.column(0), beta);
for j in 1 .. mid.ncols() {
scratch.gemv(N::one(), lhs, &mid.column(j), N::zero());
self.ger_symm(alpha, &scratch, &lhs.column(j), N::one());
}
*/
scratch.gemv_symm(N::one(), mid, &rhs.column(0), N::zero());
self.column_mut(0).gemv(alpha, rhs, &scratch, beta);
for j in 1 .. mid.ncols() {
scratch.gemv_symm(N::one(), mid, &rhs.column(j), N::zero());
self.slice_range_mut(j .., j).gemv(alpha, &rhs.rows_range(j ..), &scratch, N::one());
}
}
}

View File

@ -6,7 +6,7 @@ use num::{Zero, One, Signed};
use alga::general::{ClosedMul, ClosedDiv, ClosedAdd, ClosedSub, ClosedNeg}; use alga::general::{ClosedMul, ClosedDiv, ClosedAdd, ClosedSub, ClosedNeg};
use core::{DefaultAllocator, Scalar, Matrix, MatrixN, MatrixMN, MatrixSum}; use core::{DefaultAllocator, Scalar, Matrix, MatrixN, MatrixMN, MatrixSum, SquareMatrix};
use core::dimension::{Dim, DimName, DimProd, DimMul}; use core::dimension::{Dim, DimName, DimProd, DimMul};
use core::constraint::{ShapeConstraint, SameNumberOfRows, SameNumberOfColumns, AreMultipliable, DimEq}; use core::constraint::{ShapeConstraint, SameNumberOfRows, SameNumberOfColumns, AreMultipliable, DimEq};
use core::storage::{Storage, StorageMut, ContiguousStorageMut}; use core::storage::{Storage, StorageMut, ContiguousStorageMut};

View File

@ -52,4 +52,23 @@ quickcheck! {
relative_eq!(a1.lower_triangle(), a2) relative_eq!(a1.lower_triangle(), a2)
} }
fn quadform_symm(n: usize, alpha: f64, beta: f64) -> bool {
let n = cmp::max(1, cmp::min(n, 50));
let lhs = DMatrix::<f64>::new_random(6, n);
let mut mid = DMatrix::<f64>::new_random(n, n);
let mut res = DMatrix::new_random(6, 6);
let mut scratch = Vector6::zeros();
mid.fill_upper_triangle_with_lower_triangle();
let expected = &res * beta + &lhs * &mid * lhs.transpose() * alpha;
res.quadform_symm(&mut scratch, alpha, &lhs, &mid, beta);
res.fill_upper_triangle_with_lower_triangle();
println!("{}{}", res, expected);
relative_eq!(res.lower_triangle(), expected.lower_triangle(), epsilon = 1.0e-7)
}
} }

View File

@ -1,8 +1,8 @@
// mod conversion; // mod conversion;
// mod edition; // mod edition;
// mod matrix; // mod matrix;
mod matrix_slice; // mod matrix_slice;
// mod blas; mod blas;
// mod serde; // mod serde;
// #[cfg(feature = "abomonation-serialize")] // #[cfg(feature = "abomonation-serialize")]
// mod abomonation; // mod abomonation;