Fix quadratic form computation.
For the moment only the version that does not make any assumption regarding symmetry is implemented.
This commit is contained in:
parent
39d20306f1
commit
1ee8a702ea
|
@ -462,36 +462,41 @@ impl<N, D1: Dim, S: StorageMut<N, D1, D1>> SquareMatrix<N, D1, S>
|
||||||
where N: Scalar + Zero + One + ClosedAdd + ClosedMul {
|
where N: Scalar + Zero + One + ClosedAdd + ClosedMul {
|
||||||
|
|
||||||
/// Computes the quadratic form `self = alpha * lrs * mid * lhs.transpose() + beta * self`.
|
/// Computes the quadratic form `self = alpha * lrs * mid * lhs.transpose() + beta * self`.
|
||||||
pub fn quadform_symm<D2, S2, R3, C3, S3, S4>(&mut self,
|
pub fn quadform_with_workspace<D2, S2, R3, C3, S3, D4, S4>(&mut self,
|
||||||
scratch: &mut Vector<N, D1, S2>,
|
work: &mut Vector<N, D2, S2>,
|
||||||
alpha: N,
|
alpha: N,
|
||||||
mid: &SquareMatrix<N, D2, S3>,
|
lhs: &Matrix<N, R3, C3, S3>,
|
||||||
rhs: &Matrix<N, R3, C3, S4>,
|
mid: &SquareMatrix<N, D4, S4>,
|
||||||
beta: N)
|
beta: N)
|
||||||
where D2: Dim, R3: Dim, C3: Dim,
|
where D2: Dim, R3: Dim, C3: Dim, D4: Dim,
|
||||||
S2: StorageMut<N, D1>,
|
S2: StorageMut<N, D2>,
|
||||||
S3: Storage<N, D2, D2>,
|
S3: Storage<N, R3, C3>,
|
||||||
S4: Storage<N, R3, C3>,
|
S4: Storage<N, D4, D4>,
|
||||||
ShapeConstraint: DimEq<D1, D2> + DimEq<D2, R3> + DimEq<D1, C3>, // FIXME: why is this one necessary?
|
ShapeConstraint: DimEq<D1, D2> +
|
||||||
|
DimEq<D1, R3> +
|
||||||
|
DimEq<D2, R3> +
|
||||||
|
DimEq<C3, D4> {
|
||||||
|
work.gemv(N::one(), lhs, &mid.column(0), N::zero());
|
||||||
|
self.ger(alpha, work, &lhs.column(0), beta);
|
||||||
|
|
||||||
|
for j in 1 .. mid.ncols() {
|
||||||
|
work.gemv(N::one(), lhs, &mid.column(j), N::zero());
|
||||||
|
self.ger(alpha, work, &lhs.column(j), N::one());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Computes the quadratic form `self = alpha * lrs * mid * lhs.transpose() + beta * self`.
|
||||||
|
pub fn quadform<R3, C3, S3, D4, S4>(&mut self,
|
||||||
|
alpha: N,
|
||||||
|
lhs: &Matrix<N, R3, C3, S3>,
|
||||||
|
mid: &SquareMatrix<N, D4, S4>,
|
||||||
|
beta: N)
|
||||||
|
where R3: Dim, C3: Dim, D4: Dim,
|
||||||
|
S3: Storage<N, R3, C3>,
|
||||||
|
S4: Storage<N, D4, D4>,
|
||||||
|
ShapeConstraint: DimEq<D1, D1> + DimEq<D1, R3> + DimEq<C3, D4>,
|
||||||
DefaultAllocator: Allocator<N, D1> {
|
DefaultAllocator: Allocator<N, D1> {
|
||||||
|
let mut work = unsafe { Vector::new_uninitialized_generic(self.data.shape().0, U1) };
|
||||||
|
self.quadform_with_workspace(&mut work, alpha, lhs, mid, beta)
|
||||||
/*
|
|
||||||
scratch.gemv(N::one(), lhs, &mid.column(0), N::zero());
|
|
||||||
self.ger_symm(alpha, &scratch, &lhs.column(0), beta);
|
|
||||||
|
|
||||||
for j in 1 .. mid.ncols() {
|
|
||||||
scratch.gemv(N::one(), lhs, &mid.column(j), N::zero());
|
|
||||||
self.ger_symm(alpha, &scratch, &lhs.column(j), N::one());
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
|
|
||||||
scratch.gemv_symm(N::one(), mid, &rhs.column(0), N::zero());
|
|
||||||
self.column_mut(0).gemv(alpha, rhs, &scratch, beta);
|
|
||||||
|
|
||||||
for j in 1 .. mid.ncols() {
|
|
||||||
scratch.gemv_symm(N::one(), mid, &rhs.column(j), N::zero());
|
|
||||||
self.slice_range_mut(j .., j).gemv(alpha, &rhs.rows_range(j ..), &scratch, N::one());
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -53,22 +53,18 @@ quickcheck! {
|
||||||
relative_eq!(a1.lower_triangle(), a2)
|
relative_eq!(a1.lower_triangle(), a2)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn quadform_symm(n: usize, alpha: f64, beta: f64) -> bool {
|
fn quadform(n: usize, alpha: f64, beta: f64) -> bool {
|
||||||
let n = cmp::max(1, cmp::min(n, 50));
|
let n = cmp::max(1, cmp::min(n, 50));
|
||||||
let lhs = DMatrix::<f64>::new_random(6, n);
|
let lhs = DMatrix::<f64>::new_random(6, n);
|
||||||
let mut mid = DMatrix::<f64>::new_random(n, n);
|
let mid = DMatrix::<f64>::new_random(n, n);
|
||||||
let mut res = DMatrix::new_random(6, 6);
|
let mut res = DMatrix::new_random(6, 6);
|
||||||
let mut scratch = Vector6::zeros();
|
|
||||||
|
|
||||||
mid.fill_upper_triangle_with_lower_triangle();
|
|
||||||
|
|
||||||
let expected = &res * beta + &lhs * &mid * lhs.transpose() * alpha;
|
let expected = &res * beta + &lhs * &mid * lhs.transpose() * alpha;
|
||||||
|
|
||||||
res.quadform_symm(&mut scratch, alpha, &lhs, &mid, beta);
|
res.quadform(alpha, &lhs, &mid , beta);
|
||||||
res.fill_upper_triangle_with_lower_triangle();
|
|
||||||
|
|
||||||
println!("{}{}", res, expected);
|
println!("{}{}", res, expected);
|
||||||
|
|
||||||
relative_eq!(res.lower_triangle(), expected.lower_triangle(), epsilon = 1.0e-7)
|
relative_eq!(res, expected, epsilon = 1.0e-7)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue