Add symmetric quadratic form computation.
This commit is contained in:
parent
94c1ab8e7b
commit
39d20306f1
|
@ -3,10 +3,11 @@ use num::{Zero, One, Signed};
|
|||
use matrixmultiply;
|
||||
use alga::general::{ClosedMul, ClosedAdd};
|
||||
|
||||
use core::{Scalar, Matrix, Vector};
|
||||
use core::{DefaultAllocator, Scalar, Matrix, SquareMatrix, Vector};
|
||||
use core::dimension::{Dim, U1, U2, U3, U4, Dynamic};
|
||||
use core::constraint::{ShapeConstraint, SameNumberOfRows, SameNumberOfColumns, AreMultipliable, DimEq};
|
||||
use core::storage::{Storage, StorageMut};
|
||||
use core::allocator::Allocator;
|
||||
|
||||
|
||||
|
||||
|
@ -273,7 +274,7 @@ impl<N, D: Dim, S> Vector<N, D, S>
|
|||
#[inline]
|
||||
pub fn gemv_symm<D2: Dim, D3: Dim, SB, SC>(&mut self,
|
||||
alpha: N,
|
||||
a: &Matrix<N, D2, D2, SB>,
|
||||
a: &SquareMatrix<N, D2, SB>,
|
||||
x: &Vector<N, D3, SC>,
|
||||
beta: N)
|
||||
where N: One,
|
||||
|
@ -449,10 +450,48 @@ impl<N, R1: Dim, C1: Dim, S: StorageMut<N, R1, C1>> Matrix<N, R1, C1, S>
|
|||
assert!(dim1 == dim2 && dim1 == dim3, "ger: dimensions mismatch.");
|
||||
|
||||
for j in 0 .. dim1 {
|
||||
// FIXME: avoid bound checks.
|
||||
let val = unsafe { *y.vget_unchecked(j) };
|
||||
let subdim = Dynamic::new(dim1 - j);
|
||||
// FIXME: avoid bound checks.
|
||||
self.generic_slice_mut((j, j), (subdim, U1)).axpy(alpha * val, &x.rows_range(j ..), beta);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<N, D1: Dim, S: StorageMut<N, D1, D1>> SquareMatrix<N, D1, S>
|
||||
where N: Scalar + Zero + One + ClosedAdd + ClosedMul {
|
||||
|
||||
/// Computes the quadratic form `self = alpha * lrs * mid * lhs.transpose() + beta * self`.
|
||||
pub fn quadform_symm<D2, S2, R3, C3, S3, S4>(&mut self,
|
||||
scratch: &mut Vector<N, D1, S2>,
|
||||
alpha: N,
|
||||
mid: &SquareMatrix<N, D2, S3>,
|
||||
rhs: &Matrix<N, R3, C3, S4>,
|
||||
beta: N)
|
||||
where D2: Dim, R3: Dim, C3: Dim,
|
||||
S2: StorageMut<N, D1>,
|
||||
S3: Storage<N, D2, D2>,
|
||||
S4: Storage<N, R3, C3>,
|
||||
ShapeConstraint: DimEq<D1, D2> + DimEq<D2, R3> + DimEq<D1, C3>, // FIXME: why is this one necessary?
|
||||
DefaultAllocator: Allocator<N, D1> {
|
||||
|
||||
|
||||
/*
|
||||
scratch.gemv(N::one(), lhs, &mid.column(0), N::zero());
|
||||
self.ger_symm(alpha, &scratch, &lhs.column(0), beta);
|
||||
|
||||
for j in 1 .. mid.ncols() {
|
||||
scratch.gemv(N::one(), lhs, &mid.column(j), N::zero());
|
||||
self.ger_symm(alpha, &scratch, &lhs.column(j), N::one());
|
||||
}
|
||||
*/
|
||||
|
||||
scratch.gemv_symm(N::one(), mid, &rhs.column(0), N::zero());
|
||||
self.column_mut(0).gemv(alpha, rhs, &scratch, beta);
|
||||
|
||||
for j in 1 .. mid.ncols() {
|
||||
scratch.gemv_symm(N::one(), mid, &rhs.column(j), N::zero());
|
||||
self.slice_range_mut(j .., j).gemv(alpha, &rhs.rows_range(j ..), &scratch, N::one());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6,7 +6,7 @@ use num::{Zero, One, Signed};
|
|||
|
||||
use alga::general::{ClosedMul, ClosedDiv, ClosedAdd, ClosedSub, ClosedNeg};
|
||||
|
||||
use core::{DefaultAllocator, Scalar, Matrix, MatrixN, MatrixMN, MatrixSum};
|
||||
use core::{DefaultAllocator, Scalar, Matrix, MatrixN, MatrixMN, MatrixSum, SquareMatrix};
|
||||
use core::dimension::{Dim, DimName, DimProd, DimMul};
|
||||
use core::constraint::{ShapeConstraint, SameNumberOfRows, SameNumberOfColumns, AreMultipliable, DimEq};
|
||||
use core::storage::{Storage, StorageMut, ContiguousStorageMut};
|
||||
|
|
|
@ -52,4 +52,23 @@ quickcheck! {
|
|||
|
||||
relative_eq!(a1.lower_triangle(), a2)
|
||||
}
|
||||
|
||||
fn quadform_symm(n: usize, alpha: f64, beta: f64) -> bool {
|
||||
let n = cmp::max(1, cmp::min(n, 50));
|
||||
let lhs = DMatrix::<f64>::new_random(6, n);
|
||||
let mut mid = DMatrix::<f64>::new_random(n, n);
|
||||
let mut res = DMatrix::new_random(6, 6);
|
||||
let mut scratch = Vector6::zeros();
|
||||
|
||||
mid.fill_upper_triangle_with_lower_triangle();
|
||||
|
||||
let expected = &res * beta + &lhs * &mid * lhs.transpose() * alpha;
|
||||
|
||||
res.quadform_symm(&mut scratch, alpha, &lhs, &mid, beta);
|
||||
res.fill_upper_triangle_with_lower_triangle();
|
||||
|
||||
println!("{}{}", res, expected);
|
||||
|
||||
relative_eq!(res.lower_triangle(), expected.lower_triangle(), epsilon = 1.0e-7)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
// mod conversion;
|
||||
// mod edition;
|
||||
// mod matrix;
|
||||
mod matrix_slice;
|
||||
// mod blas;
|
||||
// mod matrix_slice;
|
||||
mod blas;
|
||||
// mod serde;
|
||||
// #[cfg(feature = "abomonation-serialize")]
|
||||
// mod abomonation;
|
||||
|
|
Loading…
Reference in New Issue