forked from M-Labs/nalgebra
Change the SVD methods to return a Result instead of panicking
This commit is contained in:
parent
4d7b215146
commit
8b1aa2078c
@ -485,89 +485,97 @@ where
|
|||||||
|
|
||||||
/// Rebuild the original matrix.
|
/// Rebuild the original matrix.
|
||||||
///
|
///
|
||||||
/// This is useful if some of the singular values have been manually modified. Panics if the
|
/// This is useful if some of the singular values have been manually modified.
|
||||||
/// right- and left- singular vectors have not been computed at construction-time.
|
/// Returns `Err` if the right- and left- singular vectors have not been
|
||||||
pub fn recompose(self) -> MatrixMN<N, R, C> {
|
/// computed at construction-time.
|
||||||
let mut u = self.u.expect("SVD recomposition: U has not been computed.");
|
pub fn recompose(self) -> Result<MatrixMN<N, R, C>, &'static str> {
|
||||||
let v_t = self.v_t
|
match (self.u, self.v_t) {
|
||||||
.expect("SVD recomposition: V^t has not been computed.");
|
(Some(_u), Some(_v_t)) => {
|
||||||
|
let mut u = _u;
|
||||||
|
let v_t = _v_t;
|
||||||
|
|
||||||
for i in 0..self.singular_values.len() {
|
for i in 0..self.singular_values.len() {
|
||||||
let val = self.singular_values[i];
|
let val = self.singular_values[i];
|
||||||
u.column_mut(i).mul_assign(val);
|
u.column_mut(i).mul_assign(val);
|
||||||
|
}
|
||||||
|
Ok(u * v_t)
|
||||||
|
}
|
||||||
|
(None, None) => Err("SVD recomposition: U and V^t have not been computed."),
|
||||||
|
(None, _) => Err("SVD recomposition: U has not been computed."),
|
||||||
|
(_, None) => Err("SVD recomposition: V^t has not been computed.")
|
||||||
}
|
}
|
||||||
|
|
||||||
u * v_t
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Computes the pseudo-inverse of the decomposed matrix.
|
/// Computes the pseudo-inverse of the decomposed matrix.
|
||||||
///
|
///
|
||||||
/// Any singular value smaller than `eps` is assumed to be zero.
|
/// Any singular value smaller than `eps` is assumed to be zero.
|
||||||
/// Panics if the right- and left- singular vectors have not been computed at
|
/// Returns `Err` if the right- and left- singular vectors have not
|
||||||
/// construction-time.
|
/// been computed at construction-time.
|
||||||
pub fn pseudo_inverse(mut self, eps: N) -> MatrixMN<N, C, R>
|
pub fn pseudo_inverse(mut self, eps: N) -> Result<MatrixMN<N, C, R>, &'static str>
|
||||||
where
|
where
|
||||||
DefaultAllocator: Allocator<N, C, R>,
|
DefaultAllocator: Allocator<N, C, R>,
|
||||||
{
|
{
|
||||||
assert!(
|
if eps < N::zero() {
|
||||||
eps >= N::zero(),
|
Err("SVD pseudo inverse: the epsilon must be non-negative.")
|
||||||
"SVD pseudo inverse: the epsilon must be non-negative."
|
|
||||||
);
|
|
||||||
for i in 0..self.singular_values.len() {
|
|
||||||
let val = self.singular_values[i];
|
|
||||||
|
|
||||||
if val > eps {
|
|
||||||
self.singular_values[i] = N::one() / val;
|
|
||||||
} else {
|
|
||||||
self.singular_values[i] = N::zero();
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
else {
|
||||||
|
for i in 0..self.singular_values.len() {
|
||||||
|
let val = self.singular_values[i];
|
||||||
|
|
||||||
self.recompose().transpose()
|
if val > eps {
|
||||||
|
self.singular_values[i] = N::one() / val;
|
||||||
|
} else {
|
||||||
|
self.singular_values[i] = N::zero();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
self.recompose().map(|m| m.transpose())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Solves the system `self * x = b` where `self` is the decomposed matrix and `x` the unknown.
|
/// Solves the system `self * x = b` where `self` is the decomposed matrix and `x` the unknown.
|
||||||
///
|
///
|
||||||
/// Any singular value smaller than `eps` is assumed to be zero.
|
/// Any singular value smaller than `eps` is assumed to be zero.
|
||||||
/// Returns `None` if the singular vectors `U` and `V` have not been computed.
|
/// Returns `Err` if the singular vectors `U` and `V` have not been computed.
|
||||||
// FIXME: make this more generic wrt the storage types and the dimensions for `b`.
|
// FIXME: make this more generic wrt the storage types and the dimensions for `b`.
|
||||||
pub fn solve<R2: Dim, C2: Dim, S2>(
|
pub fn solve<R2: Dim, C2: Dim, S2>(
|
||||||
&self,
|
&self,
|
||||||
b: &Matrix<N, R2, C2, S2>,
|
b: &Matrix<N, R2, C2, S2>,
|
||||||
eps: N,
|
eps: N,
|
||||||
) -> MatrixMN<N, C, C2>
|
) -> Result<MatrixMN<N, C, C2>, &'static str>
|
||||||
where
|
where
|
||||||
S2: Storage<N, R2, C2>,
|
S2: Storage<N, R2, C2>,
|
||||||
DefaultAllocator: Allocator<N, C, C2> + Allocator<N, DimMinimum<R, C>, C2>,
|
DefaultAllocator: Allocator<N, C, C2> + Allocator<N, DimMinimum<R, C>, C2>,
|
||||||
ShapeConstraint: SameNumberOfRows<R, R2>,
|
ShapeConstraint: SameNumberOfRows<R, R2>,
|
||||||
{
|
{
|
||||||
assert!(
|
if eps < N::zero() {
|
||||||
eps >= N::zero(),
|
Err("SVD solve: the epsilon must be non-negative.")
|
||||||
"SVD solve: the epsilon must be non-negative."
|
}
|
||||||
);
|
else {
|
||||||
let u = self.u
|
match (&self.u, &self.v_t) {
|
||||||
.as_ref()
|
(Some(u), Some(v_t)) => {
|
||||||
.expect("SVD solve: U has not been computed.");
|
let mut ut_b = u.tr_mul(b);
|
||||||
let v_t = self.v_t
|
|
||||||
.as_ref()
|
|
||||||
.expect("SVD solve: V^t has not been computed.");
|
|
||||||
|
|
||||||
let mut ut_b = u.tr_mul(b);
|
for j in 0..ut_b.ncols() {
|
||||||
|
let mut col = ut_b.column_mut(j);
|
||||||
|
|
||||||
for j in 0..ut_b.ncols() {
|
for i in 0..self.singular_values.len() {
|
||||||
let mut col = ut_b.column_mut(j);
|
let val = self.singular_values[i];
|
||||||
|
if val > eps {
|
||||||
|
col[i] /= val;
|
||||||
|
} else {
|
||||||
|
col[i] = N::zero();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
for i in 0..self.singular_values.len() {
|
Ok(v_t.tr_mul(&ut_b))
|
||||||
let val = self.singular_values[i];
|
|
||||||
if val > eps {
|
|
||||||
col[i] /= val;
|
|
||||||
} else {
|
|
||||||
col[i] = N::zero();
|
|
||||||
}
|
}
|
||||||
|
(None, None) => Err("SVD solve: U and V^t have not been computed."),
|
||||||
|
(None, _) => Err("SVD solve: U has not been computed."),
|
||||||
|
(_, None) => Err("SVD solve: V^t has not been computed.")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
v_t.tr_mul(&ut_b)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -623,7 +631,7 @@ where
|
|||||||
/// Computes the pseudo-inverse of this matrix.
|
/// Computes the pseudo-inverse of this matrix.
|
||||||
///
|
///
|
||||||
/// All singular values below `eps` are considered equal to 0.
|
/// All singular values below `eps` are considered equal to 0.
|
||||||
pub fn pseudo_inverse(self, eps: N) -> MatrixMN<N, C, R>
|
pub fn pseudo_inverse(self, eps: N) -> Result<MatrixMN<N, C, R>, &'static str>
|
||||||
where
|
where
|
||||||
DefaultAllocator: Allocator<N, C, R>,
|
DefaultAllocator: Allocator<N, C, R>,
|
||||||
{
|
{
|
||||||
|
@ -9,7 +9,7 @@ mod quickcheck_tests {
|
|||||||
fn svd(m: DMatrix<f64>) -> bool {
|
fn svd(m: DMatrix<f64>) -> bool {
|
||||||
if m.len() > 0 {
|
if m.len() > 0 {
|
||||||
let svd = m.clone().svd(true, true);
|
let svd = m.clone().svd(true, true);
|
||||||
let recomp_m = svd.clone().recompose();
|
let recomp_m = svd.clone().recompose().unwrap();
|
||||||
let (u, s, v_t) = (svd.u.unwrap(), svd.singular_values, svd.v_t.unwrap());
|
let (u, s, v_t) = (svd.u.unwrap(), svd.singular_values, svd.v_t.unwrap());
|
||||||
let ds = DMatrix::from_diagonal(&s);
|
let ds = DMatrix::from_diagonal(&s);
|
||||||
|
|
||||||
@ -90,7 +90,7 @@ mod quickcheck_tests {
|
|||||||
fn svd_pseudo_inverse(m: DMatrix<f64>) -> bool {
|
fn svd_pseudo_inverse(m: DMatrix<f64>) -> bool {
|
||||||
if m.len() > 0 {
|
if m.len() > 0 {
|
||||||
let svd = m.clone().svd(true, true);
|
let svd = m.clone().svd(true, true);
|
||||||
let pinv = svd.pseudo_inverse(1.0e-10);
|
let pinv = svd.pseudo_inverse(1.0e-10).unwrap();
|
||||||
|
|
||||||
if m.nrows() > m.ncols() {
|
if m.nrows() > m.ncols() {
|
||||||
println!("{}", &pinv * &m);
|
println!("{}", &pinv * &m);
|
||||||
@ -117,10 +117,10 @@ mod quickcheck_tests {
|
|||||||
let b1 = DVector::new_random(n);
|
let b1 = DVector::new_random(n);
|
||||||
let b2 = DMatrix::new_random(n, nb);
|
let b2 = DMatrix::new_random(n, nb);
|
||||||
|
|
||||||
let sol1 = svd.solve(&b1, 1.0e-7);
|
let sol1 = svd.solve(&b1, 1.0e-7).unwrap();
|
||||||
let sol2 = svd.solve(&b2, 1.0e-7);
|
let sol2 = svd.solve(&b2, 1.0e-7).unwrap();
|
||||||
|
|
||||||
let recomp = svd.recompose();
|
let recomp = svd.recompose().unwrap();
|
||||||
if !relative_eq!(m, recomp, epsilon = 1.0e-6) {
|
if !relative_eq!(m, recomp, epsilon = 1.0e-6) {
|
||||||
println!("{}{}", m, recomp);
|
println!("{}{}", m, recomp);
|
||||||
}
|
}
|
||||||
@ -262,22 +262,22 @@ fn svd_singular_horizontal() {
|
|||||||
fn svd_zeros() {
|
fn svd_zeros() {
|
||||||
let m = DMatrix::from_element(10, 10, 0.0);
|
let m = DMatrix::from_element(10, 10, 0.0);
|
||||||
let svd = m.clone().svd(true, true);
|
let svd = m.clone().svd(true, true);
|
||||||
assert_eq!(m, svd.recompose());
|
assert_eq!(Ok(m), svd.recompose());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn svd_identity() {
|
fn svd_identity() {
|
||||||
let m = DMatrix::<f64>::identity(10, 10);
|
let m = DMatrix::<f64>::identity(10, 10);
|
||||||
let svd = m.clone().svd(true, true);
|
let svd = m.clone().svd(true, true);
|
||||||
assert_eq!(m, svd.recompose());
|
assert_eq!(Ok(m), svd.recompose());
|
||||||
|
|
||||||
let m = DMatrix::<f64>::identity(10, 15);
|
let m = DMatrix::<f64>::identity(10, 15);
|
||||||
let svd = m.clone().svd(true, true);
|
let svd = m.clone().svd(true, true);
|
||||||
assert_eq!(m, svd.recompose());
|
assert_eq!(Ok(m), svd.recompose());
|
||||||
|
|
||||||
let m = DMatrix::<f64>::identity(15, 10);
|
let m = DMatrix::<f64>::identity(15, 10);
|
||||||
let svd = m.clone().svd(true, true);
|
let svd = m.clone().svd(true, true);
|
||||||
assert_eq!(m, svd.recompose());
|
assert_eq!(Ok(m), svd.recompose());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@ -294,7 +294,7 @@ fn svd_with_delimited_subproblem() {
|
|||||||
m[(8,8)] = 16.0; m[(3,9)] = 17.0;
|
m[(8,8)] = 16.0; m[(3,9)] = 17.0;
|
||||||
m[(9,9)] = 18.0;
|
m[(9,9)] = 18.0;
|
||||||
let svd = m.clone().svd(true, true);
|
let svd = m.clone().svd(true, true);
|
||||||
assert!(relative_eq!(m, svd.recompose(), epsilon = 1.0e-7));
|
assert!(relative_eq!(m, svd.recompose().unwrap(), epsilon = 1.0e-7));
|
||||||
|
|
||||||
// Rectangular versions.
|
// Rectangular versions.
|
||||||
let mut m = DMatrix::<f64>::from_element(15, 10, 0.0);
|
let mut m = DMatrix::<f64>::from_element(15, 10, 0.0);
|
||||||
@ -309,10 +309,10 @@ fn svd_with_delimited_subproblem() {
|
|||||||
m[(8,8)] = 16.0; m[(3,9)] = 17.0;
|
m[(8,8)] = 16.0; m[(3,9)] = 17.0;
|
||||||
m[(9,9)] = 18.0;
|
m[(9,9)] = 18.0;
|
||||||
let svd = m.clone().svd(true, true);
|
let svd = m.clone().svd(true, true);
|
||||||
assert!(relative_eq!(m, svd.recompose(), epsilon = 1.0e-7));
|
assert!(relative_eq!(m, svd.recompose().unwrap(), epsilon = 1.0e-7));
|
||||||
|
|
||||||
let svd = m.transpose().svd(true, true);
|
let svd = m.transpose().svd(true, true);
|
||||||
assert!(relative_eq!(m.transpose(), svd.recompose(), epsilon = 1.0e-7));
|
assert!(relative_eq!(m.transpose(), svd.recompose().unwrap(), epsilon = 1.0e-7));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@ -328,7 +328,15 @@ fn svd_fail() {
|
|||||||
println!("Singular values: {}", svd.singular_values);
|
println!("Singular values: {}", svd.singular_values);
|
||||||
println!("u: {:.5}", svd.u.unwrap());
|
println!("u: {:.5}", svd.u.unwrap());
|
||||||
println!("v: {:.5}", svd.v_t.unwrap());
|
println!("v: {:.5}", svd.v_t.unwrap());
|
||||||
let recomp = svd.recompose();
|
let recomp = svd.recompose().unwrap();
|
||||||
println!("{:.5}{:.5}", m, recomp);
|
println!("{:.5}{:.5}", m, recomp);
|
||||||
assert!(relative_eq!(m, recomp, epsilon = 1.0e-5));
|
assert!(relative_eq!(m, recomp, epsilon = 1.0e-5));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn svd_err() {
|
||||||
|
let m = DMatrix::from_element(10, 10, 0.0);
|
||||||
|
let svd = m.clone().svd(false, false);
|
||||||
|
assert_eq!(Err("SVD recomposition: U and V^t have not been computed."), svd.clone().recompose());
|
||||||
|
assert_eq!(Err("SVD pseudo inverse: the epsilon must be non-negative."), svd.clone().pseudo_inverse(-1.0));
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user