Fix most tests.

This commit is contained in:
sebcrozet 2019-03-19 12:00:10 +01:00
parent e4748c69ce
commit 2f0d95bdbb
10 changed files with 124 additions and 60 deletions

View File

@ -770,6 +770,14 @@ impl<N: Scalar, R: Dim, C: Dim, S: StorageMut<N, R, C>> Matrix<N, R, C, S> {
}
}
// FIXME: rename `apply` to `apply_mut` and `apply_into` to `apply`?
/// Returns `self` with each of its components replaced by the result of a closure `f` applied on it.
#[inline]
pub fn apply_into<F: FnMut(N) -> N>(mut self, mut f: F) -> Self{
self.apply(f);
self
}
/// Replaces each component of `self` by the result of a closure `f` applied on it.
#[inline]
pub fn apply<F: FnMut(N) -> N>(&mut self, mut f: F) {

View File

@ -61,6 +61,23 @@ impl<N: Complex, D: Dim, S: Storage<N, D>> Reflection<N, D, S> {
}
}
// FIXME: naming convention: reflect_to, reflect_assign ?
/// Applies the reflection to the columns of `rhs`.
pub fn reflect_with_sign<R2: Dim, C2: Dim, S2>(&self, rhs: &mut Matrix<N, R2, C2, S2>, sign: N)
where
S2: StorageMut<N, R2, C2>,
ShapeConstraint: SameNumberOfRows<R2, D>,
{
for i in 0..rhs.ncols() {
// NOTE: we borrow the column twice here. First it is borrowed immutably for the
// dot product, and then mutably. Somehow, this allows significantly
// better optimizations of the dot product from the compiler.
let m_two = sign.scale(::convert(-2.0f64));
let factor = (self.axis.cdot(&rhs.column(i)) - self.bias) * m_two;
rhs.column_mut(i).axpy(factor, &self.axis, sign);
}
}
/// Applies the reflection to the rows of `lhs`.
pub fn reflect_rows<R2: Dim, C2: Dim, S2, S3>(
&self,
@ -81,4 +98,26 @@ impl<N: Complex, D: Dim, S: Storage<N, D>> Reflection<N, D, S> {
let m_two: N = ::convert(-2.0f64);
lhs.ger(m_two, &work, &self.axis.conjugate(), N::one());
}
/// Applies the reflection to the rows of `lhs`.
pub fn reflect_rows_with_sign<R2: Dim, C2: Dim, S2, S3>(
&self,
lhs: &mut Matrix<N, R2, C2, S2>,
work: &mut Vector<N, R2, S3>,
sign: N,
) where
S2: StorageMut<N, R2, C2>,
S3: StorageMut<N, R2>,
ShapeConstraint: DimEq<C2, D> + AreMultipliable<R2, C2, D, U1>,
DefaultAllocator: Allocator<N, D>
{
lhs.mul_to(&self.axis, work);
if !self.bias.is_zero() {
work.add_scalar_mut(-self.bias);
}
let m_two = sign.scale(::convert(-2.0f64));
lhs.ger(m_two, &work, &self.axis.conjugate(), sign);
}
}

View File

@ -200,11 +200,11 @@ where
let d = nrows.min(ncols);
let mut res = MatrixN::identity_generic(d, d);
res.set_diagonal(&self.diagonal);
res.set_diagonal(&self.diagonal.map(|e| N::from_real(e.modulus())));
let start = self.axis_shift();
res.slice_mut(start, (d.value() - 1, d.value() - 1))
.set_diagonal(&self.off_diagonal);
.set_diagonal(&self.off_diagonal.map(|e| N::from_real(e.modulus())));
res
}
@ -225,7 +225,14 @@ where
let refl = Reflection::new(Unit::new_unchecked(axis), N::zero());
let mut res_rows = res.slice_range_mut(i + shift.., i..);
refl.reflect(&mut res_rows);
let sign = if self.upper_diagonal {
self.diagonal[i].signum()
} else {
self.off_diagonal[i].signum()
};
refl.reflect_with_sign(&mut res_rows, sign);
}
res
@ -251,7 +258,14 @@ where
let refl = Reflection::new(Unit::new_unchecked(axis_packed), N::zero());
let mut res_rows = res.slice_range_mut(i.., i + shift..);
refl.reflect_rows(&mut res_rows, &mut work.rows_range_mut(i..));
let sign = if self.upper_diagonal {
self.off_diagonal[i].signum()
} else {
self.diagonal[i].signum()
};
refl.reflect_rows_with_sign(&mut res_rows, &mut work.rows_range_mut(i..), sign);
}
res

View File

@ -107,7 +107,7 @@ where DefaultAllocator: Allocator<N, D, D> + Allocator<N, D> + Allocator<N, DimD
self.hess.fill_lower_triangle(N::zero(), 2);
self.hess
.slice_mut((1, 0), (dim - 1, dim - 1))
.set_diagonal(&self.subdiag);
.set_diagonal(&self.subdiag.map(|e| N::from_real(e.modulus())));
self.hess
}
@ -122,13 +122,13 @@ where DefaultAllocator: Allocator<N, D, D> + Allocator<N, D> + Allocator<N, DimD
let mut res = self.hess.clone();
res.fill_lower_triangle(N::zero(), 2);
res.slice_mut((1, 0), (dim - 1, dim - 1))
.set_diagonal(&self.subdiag);
.set_diagonal(&self.subdiag.map(|e| N::from_real(e.modulus())));
res
}
/// Computes the orthogonal matrix `Q` of this decomposition.
pub fn q(&self) -> MatrixN<N, D> {
householder::assemble_q(&self.hess)
householder::assemble_q(&self.hess, self.subdiag.as_slice())
}
#[doc(hidden)]

View File

@ -60,10 +60,11 @@ pub fn clear_column_unchecked<N: Complex, R: Dim, C: Dim>(
if not_zero {
let refl = Reflection::new(Unit::new_unchecked(axis), N::zero());
let sign = reflection_norm.signum();
if let Some(mut work) = bilateral {
refl.reflect_rows(&mut right, &mut work);
refl.reflect_rows_with_sign(&mut right, &mut work, sign);
}
refl.reflect(&mut right.rows_range_mut(icol + shift..));
refl.reflect_with_sign(&mut right.rows_range_mut(icol + shift..), sign.conjugate());
}
}
@ -90,9 +91,10 @@ pub fn clear_row_unchecked<N: Complex, R: Dim, C: Dim>(
if not_zero {
let refl = Reflection::new(Unit::new_unchecked(axis), N::zero());
refl.reflect_rows(
refl.reflect_rows_with_sign(
&mut bottom.columns_range_mut(irow + shift..),
&mut work.rows_range_mut(irow + 1..),
reflection_norm.signum().conjugate(),
);
top.columns_range_mut(irow + shift..)
.tr_copy_from(&refl.axis());
@ -101,11 +103,11 @@ pub fn clear_row_unchecked<N: Complex, R: Dim, C: Dim>(
}
}
/// Computes the orthogonal transformation described by the elementary reflector axices stored on
/// Computes the orthogonal transformation described by the elementary reflector axii stored on
/// the lower-diagonal element of the given matrix.
/// matrices.
#[doc(hidden)]
pub fn assemble_q<N: Complex, D: Dim>(m: &MatrixN<N, D>) -> MatrixN<N, D>
pub fn assemble_q<N: Complex, D: Dim>(m: &MatrixN<N, D>, signs: &[N]) -> MatrixN<N, D>
where DefaultAllocator: Allocator<N, D, D> {
assert!(m.is_square());
let dim = m.data.shape().0;
@ -119,7 +121,7 @@ where DefaultAllocator: Allocator<N, D, D> {
let refl = Reflection::new(Unit::new_unchecked(axis), N::zero());
let mut res_rows = res.slice_range_mut(i + 1.., i..);
refl.reflect(&mut res_rows);
refl.reflect_with_sign(&mut res_rows, signs[i].signum());
}
res

View File

@ -1,5 +1,6 @@
#[cfg(feature = "serde-serialize")]
use serde::{Deserialize, Serialize};
use num::Zero;
use alga::general::Complex;
use allocator::{Allocator, Reallocator};
@ -83,7 +84,7 @@ where DefaultAllocator: Allocator<N, R, C> + Allocator<N, R> + Allocator<N, DimM
{
let (nrows, ncols) = self.qr.data.shape();
let mut res = self.qr.rows_generic(0, nrows.min(ncols)).upper_triangle();
res.set_diagonal(&self.diag);
res.set_diagonal(&self.diag.map(|e| N::from_real(e.modulus())));
res
}
@ -100,7 +101,7 @@ where DefaultAllocator: Allocator<N, R, C> + Allocator<N, R> + Allocator<N, DimM
let (nrows, ncols) = self.qr.data.shape();
let mut res = self.qr.resize_generic(nrows.min(ncols), ncols, N::zero());
res.fill_lower_triangle(N::zero(), 1);
res.set_diagonal(&self.diag);
res.set_diagonal(&self.diag.map(|e| N::from_real(e.modulus())));
res
}
@ -120,7 +121,7 @@ where DefaultAllocator: Allocator<N, R, C> + Allocator<N, R> + Allocator<N, DimM
let refl = Reflection::new(Unit::new_unchecked(axis), N::zero());
let mut res_rows = res.slice_range_mut(i.., i..);
refl.reflect(&mut res_rows);
refl.reflect_with_sign(&mut res_rows, self.diag[i].signum());
}
res
@ -157,7 +158,7 @@ where DefaultAllocator: Allocator<N, R, C> + Allocator<N, R> + Allocator<N, DimM
let refl = Reflection::new(Unit::new_unchecked(axis), N::zero());
let mut rhs_rows = rhs.rows_range_mut(i..);
refl.reflect(&mut rhs_rows);
refl.reflect_with_sign(&mut rhs_rows, self.diag[i].signum().conjugate());
}
}
}
@ -226,13 +227,13 @@ where DefaultAllocator: Allocator<N, D, D> + Allocator<N, D>
let coeff;
unsafe {
let diag = *self.diag.vget_unchecked(i);
let diag = self.diag.vget_unchecked(i).modulus();
if diag.is_zero() {
return false;
}
coeff = *b.vget_unchecked(i) / diag;
coeff = b.vget_unchecked(i).unscale(diag);
*b.vget_unchecked_mut(i) = coeff;
}

View File

@ -543,7 +543,7 @@ where
}
}
self.recompose().map(|m| m.transpose())
self.recompose().map(|m| m.conjugate_transpose())
}
}
@ -568,7 +568,7 @@ where
else {
match (&self.u, &self.v_t) {
(Some(u), Some(v_t)) => {
let mut ut_b = u.tr_mul(b);
let mut ut_b = u.conjugate().tr_mul(b);
for j in 0..ut_b.ncols() {
let mut col = ut_b.column_mut(j);
@ -583,7 +583,7 @@ where
}
}
Ok(v_t.tr_mul(&ut_b))
Ok(v_t.conjugate().tr_mul(&ut_b))
}
(None, None) => Err("SVD solve: U and V^t have not been computed."),
(None, _) => Err("SVD solve: U has not been computed."),

View File

@ -79,6 +79,7 @@ where DefaultAllocator: Allocator<N, D, D> + Allocator<N, DimDiff<D, U1>>
p.cgemv_symm(::convert(2.0), &m, &axis, N::zero());
let dot = axis.cdot(&p);
// p.axpy(-dot, &axis.conjugate(), N::one());
m.ger_symm(-N::one(), &p, &axis.conjugate(), N::one());
m.ger_symm(-N::one(), &axis, &p.conjugate(), N::one());
@ -106,32 +107,30 @@ where DefaultAllocator: Allocator<N, D, D> + Allocator<N, DimDiff<D, U1>>
let diag = self.diagonal();
let q = self.q();
(q, diag, self.off_diagonal)
(q, diag, self.off_diagonal.apply_into(|e| N::from_real(e.modulus())))
}
/// Retrieve the diagonal, and off diagonal elements of this decomposition.
pub fn unpack_tridiagonal(self) -> (VectorN<N, D>, VectorN<N, DimDiff<D, U1>>)
pub fn unpack_tridiagonal(mut self) -> (VectorN<N, D>, VectorN<N, DimDiff<D, U1>>)
where DefaultAllocator: Allocator<N, D> {
let diag = self.diagonal();
(diag, self.off_diagonal)
(diag, self.off_diagonal.apply_into(|e| N::from_real(e.modulus())))
}
/// The diagonal components of this decomposition.
pub fn diagonal(&self) -> VectorN<N, D>
where DefaultAllocator: Allocator<N, D> {
self.tri.diagonal()
}
where DefaultAllocator: Allocator<N, D> { self.tri.diagonal() }
/// The off-diagonal components of this decomposition.
pub fn off_diagonal(&self) -> &VectorN<N, DimDiff<D, U1>>
pub fn off_diagonal(&self) -> VectorN<N, DimDiff<D, U1>>
where DefaultAllocator: Allocator<N, D> {
&self.off_diagonal
self.off_diagonal.map(|e| N::from_real(e.modulus()))
}
/// Computes the orthogonal matrix `Q` of this decomposition.
pub fn q(&self) -> MatrixN<N, D> {
householder::assemble_q(&self.tri)
householder::assemble_q(&self.tri, self.off_diagonal.as_slice())
}
/// Recomputes the original symmetric matrix.
@ -141,8 +140,9 @@ where DefaultAllocator: Allocator<N, D, D> + Allocator<N, DimDiff<D, U1>>
self.tri.fill_upper_triangle(N::zero(), 2);
for i in 0..self.off_diagonal.len() {
self.tri[(i + 1, i)] = self.off_diagonal[i];
self.tri[(i, i + 1)] = self.off_diagonal[i].conjugate();
let val = N::from_real(self.off_diagonal[i].modulus());
self.tri[(i + 1, i)] = val;
self.tri[(i, i + 1)] = val;
}
&q * self.tri * q.conjugate_transpose()

View File

@ -41,7 +41,6 @@ quickcheck! {
relative_eq!(m, &u * d * &v_t, epsilon = 1.0e-7)
}
fn bidiagonal_static_square(m: Matrix4<RandComplex<f64>>) -> bool {
let m = m.map(|e| e.0);
let bidiagonal = m.bidiagonalize();

View File

@ -5,14 +5,15 @@ use na::{DMatrix, Matrix6};
mod quickcheck_tests {
use na::{
DMatrix, DVector, Matrix2, Matrix2x5, Matrix3, Matrix3x5, Matrix4, Matrix5x2, Matrix5x3,
Complex
};
use std::cmp;
use core::helper::{RandScalar, RandComplex};
quickcheck! {
/*
fn svd(m: DMatrix<f64>) -> bool {
fn svd(m: DMatrix<RandComplex<f64>>) -> bool {
let m = m.map(|e| e.0);
if m.len() > 0 {
let svd = m.clone().svd(true, true);
let recomp_m = svd.clone().recompose().unwrap();
@ -21,7 +22,7 @@ mod quickcheck_tests {
println!("{}{}", &m, &u * &ds * &v_t);
s.iter().all(|e| *e >= 0.0) &&
s.iter().all(|e| e.real() >= 0.0) &&
relative_eq!(&u * ds * &v_t, recomp_m, epsilon = 1.0e-5) &&
relative_eq!(m, recomp_m, epsilon = 1.0e-5)
}
@ -30,60 +31,62 @@ mod quickcheck_tests {
}
}
fn svd_static_5_3(m: Matrix5x3<f64>) -> bool {
fn svd_static_5_3(m: Matrix5x3<RandComplex<f64>>) -> bool {
let m = m.map(|e| e.0);
let svd = m.svd(true, true);
let (u, s, v_t) = (svd.u.unwrap(), svd.singular_values, svd.v_t.unwrap());
let ds = Matrix3::from_diagonal(&s);
s.iter().all(|e| *e >= 0.0) &&
s.iter().all(|e| e.real() >= 0.0) &&
relative_eq!(m, &u * ds * &v_t, epsilon = 1.0e-5) &&
u.is_orthogonal(1.0e-5) &&
v_t.is_orthogonal(1.0e-5)
}
fn svd_static_5_2(m: Matrix5x2<f64>) -> bool {
fn svd_static_5_2(m: Matrix5x2<RandComplex<f64>>) -> bool {
let m = m.map(|e| e.0);
let svd = m.svd(true, true);
let (u, s, v_t) = (svd.u.unwrap(), svd.singular_values, svd.v_t.unwrap());
let ds = Matrix2::from_diagonal(&s);
s.iter().all(|e| *e >= 0.0) &&
s.iter().all(|e| e.real() >= 0.0) &&
relative_eq!(m, &u * ds * &v_t, epsilon = 1.0e-5) &&
u.is_orthogonal(1.0e-5) &&
v_t.is_orthogonal(1.0e-5)
}
fn svd_static_3_5(m: Matrix3x5<f64>) -> bool {
fn svd_static_3_5(m: Matrix3x5<RandComplex<f64>>) -> bool {
let m = m.map(|e| e.0);
let svd = m.svd(true, true);
let (u, s, v_t) = (svd.u.unwrap(), svd.singular_values, svd.v_t.unwrap());
let ds = Matrix3::from_diagonal(&s);
s.iter().all(|e| *e >= 0.0) &&
s.iter().all(|e| e.real() >= 0.0) &&
relative_eq!(m, u * ds * v_t, epsilon = 1.0e-5)
}
fn svd_static_2_5(m: Matrix2x5<f64>) -> bool {
fn svd_static_2_5(m: Matrix2x5<RandComplex<f64>>) -> bool {
let m = m.map(|e| e.0);
let svd = m.svd(true, true);
let (u, s, v_t) = (svd.u.unwrap(), svd.singular_values, svd.v_t.unwrap());
let ds = Matrix2::from_diagonal(&s);
s.iter().all(|e| *e >= 0.0) &&
s.iter().all(|e| e.real() >= 0.0) &&
relative_eq!(m, u * ds * v_t, epsilon = 1.0e-5)
}
fn svd_static_square(m: Matrix4<f64>) -> bool {
fn svd_static_square(m: Matrix4<RandComplex<f64>>) -> bool {
let m = m.map(|e| e.0);
let svd = m.svd(true, true);
let (u, s, v_t) = (svd.u.unwrap(), svd.singular_values, svd.v_t.unwrap());
let ds = Matrix4::from_diagonal(&s);
s.iter().all(|e| *e >= 0.0) &&
s.iter().all(|e| e.real() >= 0.0) &&
relative_eq!(m, u * ds * v_t, epsilon = 1.0e-5) &&
u.is_orthogonal(1.0e-5) &&
v_t.is_orthogonal(1.0e-5)
}
*/
fn svd_static_square_2x2(m: Matrix2<RandComplex<f64>>) -> bool {
let m = m.map(|e| e.0);
@ -102,8 +105,9 @@ mod quickcheck_tests {
v_t.is_orthogonal(1.0e-5)
}
/*
fn svd_pseudo_inverse(m: DMatrix<f64>) -> bool {
fn svd_pseudo_inverse(m: DMatrix<RandComplex<f64>>) -> bool {
let m = m.map(|e| e.0);
if m.len() > 0 {
let svd = m.clone().svd(true, true);
let pinv = svd.pseudo_inverse(1.0e-10).unwrap();
@ -125,13 +129,13 @@ mod quickcheck_tests {
fn svd_solve(n: usize, nb: usize) -> bool {
let n = cmp::max(1, cmp::min(n, 10));
let nb = cmp::min(nb, 10);
let m = DMatrix::<f64>::new_random(n, n);
let m = DMatrix::<RandComplex<f64>>::new_random(n, n).map(|e| e.0);
let svd = m.clone().svd(true, true);
if svd.rank(1.0e-7) == n {
let b1 = DVector::new_random(n);
let b2 = DMatrix::new_random(n, nb);
let b1 = DVector::<RandComplex<f64>>::new_random(n).map(|e| e.0);
let b2 = DMatrix::<RandComplex<f64>>::new_random(n, nb).map(|e| e.0);
let sol1 = svd.solve(&b1, 1.0e-7).unwrap();
let sol2 = svd.solve(&b2, 1.0e-7).unwrap();
@ -153,11 +157,10 @@ mod quickcheck_tests {
true
}
*/
}
}
/*
// Test proposed on the issue #176 of rulinalg.
#[test]
fn svd_singular() {
@ -357,5 +360,3 @@ fn svd_err() {
assert_eq!(Err("SVD recomposition: U and V^t have not been computed."), svd.clone().recompose());
assert_eq!(Err("SVD pseudo inverse: the epsilon must be non-negative."), svd.clone().pseudo_inverse(-1.0));
}
*/