Implement CsMatrix: axpy_cs, transpose, Add and Mul.

This commit is contained in:
sebcrozet 2018-10-20 22:27:18 +02:00
parent 0d24cf4dc0
commit 9fa3e7a769
9 changed files with 122 additions and 6 deletions

View File

@ -23,6 +23,7 @@ stdweb = [ "rand/stdweb" ]
arbitrary = [ "quickcheck" ] arbitrary = [ "quickcheck" ]
serde-serialize = [ "serde", "serde_derive", "num-complex/serde" ] serde-serialize = [ "serde", "serde_derive", "num-complex/serde" ]
abomonation-serialize = [ "abomonation" ] abomonation-serialize = [ "abomonation" ]
sparse = [ ]
debug = [ ] debug = [ ]
alloc = [ ] alloc = [ ]

View File

@ -81,10 +81,12 @@ an optimized set of tools for computer graphics and physics. Those features incl
#![deny(non_upper_case_globals)] #![deny(non_upper_case_globals)]
#![deny(unused_qualifications)] #![deny(unused_qualifications)]
#![deny(unused_results)] #![deny(unused_results)]
#![deny(missing_docs)] #![warn(missing_docs)] // FIXME: deny this
#![warn(incoherent_fundamental_impls)] #![warn(incoherent_fundamental_impls)]
#![doc(html_favicon_url = "http://nalgebra.org/img/favicon.ico", #![doc(
html_root_url = "http://nalgebra.org/rustdoc")] html_favicon_url = "http://nalgebra.org/img/favicon.ico",
html_root_url = "http://nalgebra.org/rustdoc"
)]
#![cfg_attr(not(feature = "std"), no_std)] #![cfg_attr(not(feature = "std"), no_std)]
#![cfg_attr(all(feature = "alloc", not(feature = "std")), feature(alloc))] #![cfg_attr(all(feature = "alloc", not(feature = "std")), feature(alloc))]
@ -126,6 +128,8 @@ pub mod base;
pub mod debug; pub mod debug;
pub mod geometry; pub mod geometry;
pub mod linalg; pub mod linalg;
#[cfg(feature = "sparse")]
pub mod sparse;
#[cfg(feature = "std")] #[cfg(feature = "std")]
#[deprecated( #[deprecated(
@ -135,6 +139,8 @@ pub use base as core;
pub use base::*; pub use base::*;
pub use geometry::*; pub use geometry::*;
pub use linalg::*; pub use linalg::*;
#[cfg(feature = "sparse")]
pub use sparse::*;
use std::cmp::{self, Ordering, PartialOrd}; use std::cmp::{self, Ordering, PartialOrd};

View File

@ -103,7 +103,7 @@ pub struct CsMatrix<N: Scalar, R: Dim, C: Dim, S: CsStorage<N, R, C> = CsVecStor
_phantoms: PhantomData<(N, R, C)>, _phantoms: PhantomData<(N, R, C)>,
} }
pub type CsVector<N, R, S> = CsMatrix<N, R, U1, S>; pub type CsVector<N, R, S = CsVecStorage<N, R, U1>> = CsMatrix<N, R, U1, S>;
impl<N: Scalar, R: Dim, C: Dim> CsMatrix<N, R, C> impl<N: Scalar, R: Dim, C: Dim> CsMatrix<N, R, C>
where where
@ -277,11 +277,14 @@ impl<N: Scalar + Zero + ClosedAdd + ClosedMul, D: Dim, S: StorageMut<N, D>> Vect
} }
} }
} else { } else {
// Needed to be sure even components not present on `x` are multiplied.
*self *= beta;
for i in 0..x.nvalues() { for i in 0..x.nvalues() {
unsafe { unsafe {
let k = x.data.row_index_unchecked(i); let k = x.data.row_index_unchecked(i);
let y = self.vget_unchecked_mut(k); let y = self.vget_unchecked_mut(k);
*y = alpha * *x.data.get_value_unchecked(i) + beta * *y; *y += alpha * *x.data.get_value_unchecked(i);
} }
} }
} }

View File

@ -1,3 +1,3 @@
pub use self::cs_matrix::CsMatrix; pub use self::cs_matrix::{CsMatrix, CsVector};
mod cs_matrix; mod cs_matrix;

View File

@ -0,0 +1 @@

View File

@ -0,0 +1,18 @@
#![cfg_attr(rustfmt, rustfmt_skip)]
use na::{Matrix4x5, CsMatrix};
#[test]
fn cs_from_to_matrix() {
let m = Matrix4x5::new(
5.0, 6.0, 0.0, 8.0, 15.0,
9.0, 10.0, 11.0, 12.0, 0.0,
0.0, 0.0, 13.0, 0.0, 0.0,
0.0, 1.0, 4.0, 0.0, 14.0,
);
let cs: CsMatrix<_, _, _> = m.into();
let m2: Matrix4x5<_> = cs.into();
assert_eq!(m2, m);
}

18
tests/sparse/cs_matrix.rs Normal file
View File

@ -0,0 +1,18 @@
#![cfg_attr(rustfmt, rustfmt_skip)]
use na::{Matrix4x5, Matrix5x4, CsMatrix};
#[test]
fn cs_transpose() {
let m = Matrix4x5::new(
4.0, 1.0, 4.0, 0.0, 9.0,
5.0, 6.0, 0.0, 8.0, 10.0,
9.0, 10.0, 11.0, 12.0, 0.0,
0.0, 0.0, 1.0, 0.0, 10.0
);
let cs: CsMatrix<_, _, _> = m.into();
let cs_transposed: Matrix5x4<_> = cs.transpose().into();
assert_eq!(cs_transposed, m.transpose())
}

65
tests/sparse/cs_ops.rs Normal file
View File

@ -0,0 +1,65 @@
#![cfg_attr(rustfmt, rustfmt_skip)]
use na::{Matrix3x4, Matrix4x5, Matrix3x5, CsMatrix, Vector5, CsVector};
#[test]
fn axpy_cs() {
let mut v1 = Vector5::new(1.0, 2.0, 3.0, 4.0, 5.0);
let v2 = Vector5::new(10.0, 0.0, 30.0, 0.0, 50.0);
let expected = 5.0 * v2 + 10.0 * v1;
let cs: CsVector<_, _> = v2.into();
v1.axpy_cs(5.0, &cs, 10.0);
assert_eq!(v1, expected)
}
#[test]
fn cs_mat_mul() {
let m1 = Matrix3x4::new(
0.0, 1.0, 4.0, 0.0,
5.0, 6.0, 0.0, 8.0,
9.0, 10.0, 11.0, 12.0,
);
let m2 = Matrix4x5::new(
5.0, 6.0, 0.0, 8.0, 15.0,
9.0, 10.0, 11.0, 12.0, 0.0,
0.0, 0.0, 13.0, 0.0, 0.0,
0.0, 1.0, 4.0, 0.0, 14.0,
);
let sm1: CsMatrix<_, _, _> = m1.into();
let sm2: CsMatrix<_, _, _> = m2.into();
let mul = &sm1 * &sm2;
assert_eq!(Matrix3x5::from(mul), m1 * m2);
}
#[test]
fn cs_mat_add() {
let m1 = Matrix4x5::new(
4.0, 1.0, 4.0, 0.0, 9.0,
5.0, 6.0, 0.0, 8.0, 10.0,
9.0, 10.0, 11.0, 12.0, 0.0,
0.0, 0.0, 1.0, 0.0, 10.0
);
let m2 = Matrix4x5::new(
0.0, 1.0, 4.0, 0.0, 14.0,
5.0, 6.0, 0.0, 8.0, 15.0,
9.0, 10.0, 11.0, 12.0, 0.0,
0.0, 0.0, 13.0, 0.0, 0.0,
);
let sm1: CsMatrix<_, _, _> = m1.into();
let sm2: CsMatrix<_, _, _> = m2.into();
let mul = &sm1 + &sm2;
assert_eq!(Matrix4x5::from(mul), m1 + m2);
}

4
tests/sparse/mod.rs Normal file
View File

@ -0,0 +1,4 @@
mod cs_construction;
mod cs_conversion;
mod cs_matrix;
mod cs_ops;