forked from M-Labs/nalgebra
Merge pull request #823 from Andlon/sparse-rework
Sparse rework: nalgebra-sparse
This commit is contained in:
commit
adc82845d1
@ -57,10 +57,19 @@ jobs:
|
|||||||
- checkout
|
- checkout
|
||||||
- run:
|
- run:
|
||||||
name: test
|
name: test
|
||||||
command: cargo test --features arbitrary --features serde-serialize --features abomonation-serialize --features sparse --features debug --features io --features compare --features libm
|
command: cargo test --features arbitrary --features serde-serialize --features abomonation-serialize --features sparse --features debug --features io --features compare --features libm --features proptest-support --features slow-tests
|
||||||
- run:
|
- run:
|
||||||
name: test nalgebra-glm
|
name: test nalgebra-glm
|
||||||
command: cargo test -p nalgebra-glm --features arbitrary --features serde-serialize --features abomonation-serialize --features sparse --features debug --features io --features compare --features libm
|
command: cargo test -p nalgebra-glm --features arbitrary --features serde-serialize --features abomonation-serialize --features sparse --features debug --features io --features compare --features libm --features slow-tests
|
||||||
|
- run:
|
||||||
|
name: test nalgebra-sparse
|
||||||
|
# Manifest-path is necessary because cargo otherwise won't correctly forward features
|
||||||
|
# We increase number of proptest cases to hopefully catch more potential bugs
|
||||||
|
command: PROPTEST_CASES=10000 cargo test --manifest-path=nalgebra-sparse/Cargo.toml --features compare,proptest-support
|
||||||
|
- run:
|
||||||
|
name: test nalgebra-sparse (slow tests)
|
||||||
|
# Unfortunately, the "slow-tests" take so much time that we need to run them with --release
|
||||||
|
command: PROPTEST_CASES=10000 cargo test --release --manifest-path=nalgebra-sparse/Cargo.toml --features compare,proptest-support,slow-tests slow
|
||||||
build-wasm:
|
build-wasm:
|
||||||
executor: rust-executor
|
executor: rust-executor
|
||||||
steps:
|
steps:
|
||||||
|
14
Cargo.toml
14
Cargo.toml
@ -35,6 +35,10 @@ io = [ "pest", "pest_derive" ]
|
|||||||
compare = [ "matrixcompare-core" ]
|
compare = [ "matrixcompare-core" ]
|
||||||
libm = [ "simba/libm" ]
|
libm = [ "simba/libm" ]
|
||||||
libm-force = [ "simba/libm_force" ]
|
libm-force = [ "simba/libm_force" ]
|
||||||
|
proptest-support = [ "proptest" ]
|
||||||
|
|
||||||
|
# This feature is only used for tests, and enables tests that require more time to run
|
||||||
|
slow-tests = []
|
||||||
|
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
@ -56,6 +60,7 @@ quickcheck = { version = "0.9", optional = true }
|
|||||||
pest = { version = "2", optional = true }
|
pest = { version = "2", optional = true }
|
||||||
pest_derive = { version = "2", optional = true }
|
pest_derive = { version = "2", optional = true }
|
||||||
matrixcompare-core = { version = "0.1", optional = true }
|
matrixcompare-core = { version = "0.1", optional = true }
|
||||||
|
proptest = { version = "0.10", optional = true, default-features = false, features = ["std"] }
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
serde_json = "1.0"
|
serde_json = "1.0"
|
||||||
@ -68,10 +73,11 @@ rand_isaac = "0.2"
|
|||||||
#criterion = "0.2.10"
|
#criterion = "0.2.10"
|
||||||
|
|
||||||
# For matrix comparison macro
|
# For matrix comparison macro
|
||||||
matrixcompare = "0.1.3"
|
matrixcompare = "0.2.0"
|
||||||
|
itertools = "0.9"
|
||||||
|
|
||||||
[workspace]
|
[workspace]
|
||||||
members = [ "nalgebra-lapack", "nalgebra-glm" ]
|
members = [ "nalgebra-lapack", "nalgebra-glm", "nalgebra-sparse" ]
|
||||||
|
|
||||||
[[bench]]
|
[[bench]]
|
||||||
name = "nalgebra_bench"
|
name = "nalgebra_bench"
|
||||||
@ -80,3 +86,7 @@ path = "benches/lib.rs"
|
|||||||
|
|
||||||
[profile.bench]
|
[profile.bench]
|
||||||
lto = true
|
lto = true
|
||||||
|
|
||||||
|
[package.metadata.docs.rs]
|
||||||
|
# Enable certain features when building docs for docs.rs
|
||||||
|
features = [ "proptest-support", "compare" ]
|
||||||
|
27
nalgebra-sparse/Cargo.toml
Normal file
27
nalgebra-sparse/Cargo.toml
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
[package]
|
||||||
|
name = "nalgebra-sparse"
|
||||||
|
version = "0.1.0"
|
||||||
|
authors = [ "Andreas Longva", "Sébastien Crozet <developer@crozet.re>" ]
|
||||||
|
edition = "2018"
|
||||||
|
|
||||||
|
[features]
|
||||||
|
proptest-support = ["proptest", "nalgebra/proptest-support"]
|
||||||
|
compare = [ "matrixcompare-core" ]
|
||||||
|
|
||||||
|
# Enable to enable running some tests that take a lot of time to run
|
||||||
|
slow-tests = []
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
nalgebra = { version="0.24", path = "../" }
|
||||||
|
num-traits = { version = "0.2", default-features = false }
|
||||||
|
proptest = { version = "0.10", optional = true }
|
||||||
|
matrixcompare-core = { version = "0.1.0", optional = true }
|
||||||
|
|
||||||
|
[dev-dependencies]
|
||||||
|
itertools = "0.9"
|
||||||
|
matrixcompare = { version = "0.2.0", features = [ "proptest-support" ] }
|
||||||
|
nalgebra = { version="0.24", path = "../", features = ["compare"] }
|
||||||
|
|
||||||
|
[package.metadata.docs.rs]
|
||||||
|
# Enable certain features when building docs for docs.rs
|
||||||
|
features = [ "proptest-support", "compare" ]
|
124
nalgebra-sparse/src/convert/impl_std_ops.rs
Normal file
124
nalgebra-sparse/src/convert/impl_std_ops.rs
Normal file
@ -0,0 +1,124 @@
|
|||||||
|
use crate::convert::serial::*;
|
||||||
|
use crate::coo::CooMatrix;
|
||||||
|
use crate::csc::CscMatrix;
|
||||||
|
use crate::csr::CsrMatrix;
|
||||||
|
use nalgebra::storage::Storage;
|
||||||
|
use nalgebra::{ClosedAdd, DMatrix, Dim, Matrix, Scalar};
|
||||||
|
use num_traits::Zero;
|
||||||
|
|
||||||
|
impl<'a, T, R, C, S> From<&'a Matrix<T, R, C, S>> for CooMatrix<T>
|
||||||
|
where
|
||||||
|
T: Scalar + Zero,
|
||||||
|
R: Dim,
|
||||||
|
C: Dim,
|
||||||
|
S: Storage<T, R, C>,
|
||||||
|
{
|
||||||
|
fn from(matrix: &'a Matrix<T, R, C, S>) -> Self {
|
||||||
|
convert_dense_coo(matrix)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, T> From<&'a CooMatrix<T>> for DMatrix<T>
|
||||||
|
where
|
||||||
|
T: Scalar + Zero + ClosedAdd,
|
||||||
|
{
|
||||||
|
fn from(coo: &'a CooMatrix<T>) -> Self {
|
||||||
|
convert_coo_dense(coo)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, T> From<&'a CooMatrix<T>> for CsrMatrix<T>
|
||||||
|
where
|
||||||
|
T: Scalar + Zero + ClosedAdd,
|
||||||
|
{
|
||||||
|
fn from(matrix: &'a CooMatrix<T>) -> Self {
|
||||||
|
convert_coo_csr(matrix)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, T> From<&'a CsrMatrix<T>> for CooMatrix<T>
|
||||||
|
where
|
||||||
|
T: Scalar + Zero + ClosedAdd,
|
||||||
|
{
|
||||||
|
fn from(matrix: &'a CsrMatrix<T>) -> Self {
|
||||||
|
convert_csr_coo(matrix)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, T, R, C, S> From<&'a Matrix<T, R, C, S>> for CsrMatrix<T>
|
||||||
|
where
|
||||||
|
T: Scalar + Zero,
|
||||||
|
R: Dim,
|
||||||
|
C: Dim,
|
||||||
|
S: Storage<T, R, C>,
|
||||||
|
{
|
||||||
|
fn from(matrix: &'a Matrix<T, R, C, S>) -> Self {
|
||||||
|
convert_dense_csr(matrix)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, T> From<&'a CsrMatrix<T>> for DMatrix<T>
|
||||||
|
where
|
||||||
|
T: Scalar + Zero + ClosedAdd,
|
||||||
|
{
|
||||||
|
fn from(matrix: &'a CsrMatrix<T>) -> Self {
|
||||||
|
convert_csr_dense(matrix)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, T> From<&'a CooMatrix<T>> for CscMatrix<T>
|
||||||
|
where
|
||||||
|
T: Scalar + Zero + ClosedAdd,
|
||||||
|
{
|
||||||
|
fn from(matrix: &'a CooMatrix<T>) -> Self {
|
||||||
|
convert_coo_csc(matrix)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, T> From<&'a CscMatrix<T>> for CooMatrix<T>
|
||||||
|
where
|
||||||
|
T: Scalar + Zero,
|
||||||
|
{
|
||||||
|
fn from(matrix: &'a CscMatrix<T>) -> Self {
|
||||||
|
convert_csc_coo(matrix)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, T, R, C, S> From<&'a Matrix<T, R, C, S>> for CscMatrix<T>
|
||||||
|
where
|
||||||
|
T: Scalar + Zero,
|
||||||
|
R: Dim,
|
||||||
|
C: Dim,
|
||||||
|
S: Storage<T, R, C>,
|
||||||
|
{
|
||||||
|
fn from(matrix: &'a Matrix<T, R, C, S>) -> Self {
|
||||||
|
convert_dense_csc(matrix)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, T> From<&'a CscMatrix<T>> for DMatrix<T>
|
||||||
|
where
|
||||||
|
T: Scalar + Zero + ClosedAdd,
|
||||||
|
{
|
||||||
|
fn from(matrix: &'a CscMatrix<T>) -> Self {
|
||||||
|
convert_csc_dense(matrix)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, T> From<&'a CscMatrix<T>> for CsrMatrix<T>
|
||||||
|
where
|
||||||
|
T: Scalar,
|
||||||
|
{
|
||||||
|
fn from(matrix: &'a CscMatrix<T>) -> Self {
|
||||||
|
convert_csc_csr(matrix)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, T> From<&'a CsrMatrix<T>> for CscMatrix<T>
|
||||||
|
where
|
||||||
|
T: Scalar,
|
||||||
|
{
|
||||||
|
fn from(matrix: &'a CsrMatrix<T>) -> Self {
|
||||||
|
convert_csr_csc(matrix)
|
||||||
|
}
|
||||||
|
}
|
40
nalgebra-sparse/src/convert/mod.rs
Normal file
40
nalgebra-sparse/src/convert/mod.rs
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
//! Routines for converting between sparse matrix formats.
|
||||||
|
//!
|
||||||
|
//! Most users should instead use the provided `From` implementations to convert between matrix
|
||||||
|
//! formats. Note that `From` implementations may not be available between all combinations of
|
||||||
|
//! sparse matrices.
|
||||||
|
//!
|
||||||
|
//! The following example illustrates how to convert between matrix formats with the `From`
|
||||||
|
//! implementations.
|
||||||
|
//!
|
||||||
|
//! ```rust
|
||||||
|
//! use nalgebra_sparse::{csr::CsrMatrix, csc::CscMatrix, coo::CooMatrix};
|
||||||
|
//! use nalgebra::DMatrix;
|
||||||
|
//!
|
||||||
|
//! // Conversion from dense
|
||||||
|
//! let dense = DMatrix::<f64>::identity(9, 8);
|
||||||
|
//! let csr = CsrMatrix::from(&dense);
|
||||||
|
//! let csc = CscMatrix::from(&dense);
|
||||||
|
//! let coo = CooMatrix::from(&dense);
|
||||||
|
//!
|
||||||
|
//! // CSR <-> CSC
|
||||||
|
//! let _ = CsrMatrix::from(&csc);
|
||||||
|
//! let _ = CscMatrix::from(&csr);
|
||||||
|
//!
|
||||||
|
//! // CSR <-> COO
|
||||||
|
//! let _ = CooMatrix::from(&csr);
|
||||||
|
//! let _ = CsrMatrix::from(&coo);
|
||||||
|
//!
|
||||||
|
//! // CSC <-> COO
|
||||||
|
//! let _ = CooMatrix::from(&csc);
|
||||||
|
//! let _ = CscMatrix::from(&coo);
|
||||||
|
//! ```
|
||||||
|
//!
|
||||||
|
//! The routines available here are able to provide more specialized APIs, giving
|
||||||
|
//! more control over the conversion process. The routines are organized by backends.
|
||||||
|
//! Currently, only the [`serial`] backend is available.
|
||||||
|
//! In the future, backends that offer parallel routines may become available.
|
||||||
|
|
||||||
|
pub mod serial;
|
||||||
|
|
||||||
|
mod impl_std_ops;
|
427
nalgebra-sparse/src/convert/serial.rs
Normal file
427
nalgebra-sparse/src/convert/serial.rs
Normal file
@ -0,0 +1,427 @@
|
|||||||
|
//! Serial routines for converting between matrix formats.
|
||||||
|
//!
|
||||||
|
//! All routines in this module are single-threaded. At present these routines offer no
|
||||||
|
//! advantage over using the [`From`] trait, but future changes to the API might offer more
|
||||||
|
//! control to the user.
|
||||||
|
use std::ops::Add;
|
||||||
|
|
||||||
|
use num_traits::Zero;
|
||||||
|
|
||||||
|
use nalgebra::storage::Storage;
|
||||||
|
use nalgebra::{ClosedAdd, DMatrix, Dim, Matrix, Scalar};
|
||||||
|
|
||||||
|
use crate::coo::CooMatrix;
|
||||||
|
use crate::cs;
|
||||||
|
use crate::csc::CscMatrix;
|
||||||
|
use crate::csr::CsrMatrix;
|
||||||
|
|
||||||
|
/// Converts a dense matrix to [`CooMatrix`].
|
||||||
|
pub fn convert_dense_coo<T, R, C, S>(dense: &Matrix<T, R, C, S>) -> CooMatrix<T>
|
||||||
|
where
|
||||||
|
T: Scalar + Zero,
|
||||||
|
R: Dim,
|
||||||
|
C: Dim,
|
||||||
|
S: Storage<T, R, C>,
|
||||||
|
{
|
||||||
|
let mut coo = CooMatrix::new(dense.nrows(), dense.ncols());
|
||||||
|
|
||||||
|
for (index, v) in dense.iter().enumerate() {
|
||||||
|
if v != &T::zero() {
|
||||||
|
// We use the fact that matrix iteration is guaranteed to be column-major
|
||||||
|
let i = index % dense.nrows();
|
||||||
|
let j = index / dense.nrows();
|
||||||
|
coo.push(i, j, v.inlined_clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
coo
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Converts a [`CooMatrix`] to a dense matrix.
|
||||||
|
pub fn convert_coo_dense<T>(coo: &CooMatrix<T>) -> DMatrix<T>
|
||||||
|
where
|
||||||
|
T: Scalar + Zero + ClosedAdd,
|
||||||
|
{
|
||||||
|
let mut output = DMatrix::repeat(coo.nrows(), coo.ncols(), T::zero());
|
||||||
|
for (i, j, v) in coo.triplet_iter() {
|
||||||
|
output[(i, j)] += v.inlined_clone();
|
||||||
|
}
|
||||||
|
output
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Converts a [`CooMatrix`] to a [`CsrMatrix`].
|
||||||
|
pub fn convert_coo_csr<T>(coo: &CooMatrix<T>) -> CsrMatrix<T>
|
||||||
|
where
|
||||||
|
T: Scalar + Zero,
|
||||||
|
{
|
||||||
|
let (offsets, indices, values) = convert_coo_cs(
|
||||||
|
coo.nrows(),
|
||||||
|
coo.row_indices(),
|
||||||
|
coo.col_indices(),
|
||||||
|
coo.values(),
|
||||||
|
);
|
||||||
|
|
||||||
|
// TODO: Avoid "try_from" since it validates the data? (requires unsafe, should benchmark
|
||||||
|
// to see if it can be justified for performance reasons)
|
||||||
|
CsrMatrix::try_from_csr_data(coo.nrows(), coo.ncols(), offsets, indices, values)
|
||||||
|
.expect("Internal error: Invalid CSR data during COO->CSR conversion")
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Converts a [`CsrMatrix`] to a [`CooMatrix`].
|
||||||
|
pub fn convert_csr_coo<T: Scalar>(csr: &CsrMatrix<T>) -> CooMatrix<T> {
|
||||||
|
let mut result = CooMatrix::new(csr.nrows(), csr.ncols());
|
||||||
|
for (i, j, v) in csr.triplet_iter() {
|
||||||
|
result.push(i, j, v.inlined_clone());
|
||||||
|
}
|
||||||
|
result
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Converts a [`CsrMatrix`] to a dense matrix.
|
||||||
|
pub fn convert_csr_dense<T>(csr: &CsrMatrix<T>) -> DMatrix<T>
|
||||||
|
where
|
||||||
|
T: Scalar + ClosedAdd + Zero,
|
||||||
|
{
|
||||||
|
let mut output = DMatrix::zeros(csr.nrows(), csr.ncols());
|
||||||
|
|
||||||
|
for (i, j, v) in csr.triplet_iter() {
|
||||||
|
output[(i, j)] += v.inlined_clone();
|
||||||
|
}
|
||||||
|
|
||||||
|
output
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Converts a dense matrix to a [`CsrMatrix`].
|
||||||
|
pub fn convert_dense_csr<T, R, C, S>(dense: &Matrix<T, R, C, S>) -> CsrMatrix<T>
|
||||||
|
where
|
||||||
|
T: Scalar + Zero,
|
||||||
|
R: Dim,
|
||||||
|
C: Dim,
|
||||||
|
S: Storage<T, R, C>,
|
||||||
|
{
|
||||||
|
let mut row_offsets = Vec::with_capacity(dense.nrows() + 1);
|
||||||
|
let mut col_idx = Vec::new();
|
||||||
|
let mut values = Vec::new();
|
||||||
|
|
||||||
|
// We have to iterate row-by-row to build the CSR matrix, which is at odds with
|
||||||
|
// nalgebra's column-major storage. The alternative would be to perform an initial sweep
|
||||||
|
// to count number of non-zeros per row.
|
||||||
|
row_offsets.push(0);
|
||||||
|
for i in 0..dense.nrows() {
|
||||||
|
for j in 0..dense.ncols() {
|
||||||
|
let v = dense.index((i, j));
|
||||||
|
if v != &T::zero() {
|
||||||
|
col_idx.push(j);
|
||||||
|
values.push(v.inlined_clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
row_offsets.push(col_idx.len());
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Consider circumventing the data validity check here
|
||||||
|
// (would require unsafe, should benchmark)
|
||||||
|
CsrMatrix::try_from_csr_data(dense.nrows(), dense.ncols(), row_offsets, col_idx, values)
|
||||||
|
.expect("Internal error: Invalid CsrMatrix format during dense-> CSR conversion")
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Converts a [`CooMatrix`] to a [`CscMatrix`].
|
||||||
|
pub fn convert_coo_csc<T>(coo: &CooMatrix<T>) -> CscMatrix<T>
|
||||||
|
where
|
||||||
|
T: Scalar + Zero,
|
||||||
|
{
|
||||||
|
let (offsets, indices, values) = convert_coo_cs(
|
||||||
|
coo.ncols(),
|
||||||
|
coo.col_indices(),
|
||||||
|
coo.row_indices(),
|
||||||
|
coo.values(),
|
||||||
|
);
|
||||||
|
|
||||||
|
// TODO: Avoid "try_from" since it validates the data? (requires unsafe, should benchmark
|
||||||
|
// to see if it can be justified for performance reasons)
|
||||||
|
CscMatrix::try_from_csc_data(coo.nrows(), coo.ncols(), offsets, indices, values)
|
||||||
|
.expect("Internal error: Invalid CSC data during COO->CSC conversion")
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Converts a [`CscMatrix`] to a [`CooMatrix`].
|
||||||
|
pub fn convert_csc_coo<T>(csc: &CscMatrix<T>) -> CooMatrix<T>
|
||||||
|
where
|
||||||
|
T: Scalar,
|
||||||
|
{
|
||||||
|
let mut coo = CooMatrix::new(csc.nrows(), csc.ncols());
|
||||||
|
for (i, j, v) in csc.triplet_iter() {
|
||||||
|
coo.push(i, j, v.inlined_clone());
|
||||||
|
}
|
||||||
|
coo
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Converts a [`CscMatrix`] to a dense matrix.
|
||||||
|
pub fn convert_csc_dense<T>(csc: &CscMatrix<T>) -> DMatrix<T>
|
||||||
|
where
|
||||||
|
T: Scalar + ClosedAdd + Zero,
|
||||||
|
{
|
||||||
|
let mut output = DMatrix::zeros(csc.nrows(), csc.ncols());
|
||||||
|
|
||||||
|
for (i, j, v) in csc.triplet_iter() {
|
||||||
|
output[(i, j)] += v.inlined_clone();
|
||||||
|
}
|
||||||
|
|
||||||
|
output
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Converts a dense matrix to a [`CscMatrix`].
|
||||||
|
pub fn convert_dense_csc<T, R, C, S>(dense: &Matrix<T, R, C, S>) -> CscMatrix<T>
|
||||||
|
where
|
||||||
|
T: Scalar + Zero,
|
||||||
|
R: Dim,
|
||||||
|
C: Dim,
|
||||||
|
S: Storage<T, R, C>,
|
||||||
|
{
|
||||||
|
let mut col_offsets = Vec::with_capacity(dense.ncols() + 1);
|
||||||
|
let mut row_idx = Vec::new();
|
||||||
|
let mut values = Vec::new();
|
||||||
|
|
||||||
|
col_offsets.push(0);
|
||||||
|
for j in 0..dense.ncols() {
|
||||||
|
for i in 0..dense.nrows() {
|
||||||
|
let v = dense.index((i, j));
|
||||||
|
if v != &T::zero() {
|
||||||
|
row_idx.push(i);
|
||||||
|
values.push(v.inlined_clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
col_offsets.push(row_idx.len());
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Consider circumventing the data validity check here
|
||||||
|
// (would require unsafe, should benchmark)
|
||||||
|
CscMatrix::try_from_csc_data(dense.nrows(), dense.ncols(), col_offsets, row_idx, values)
|
||||||
|
.expect("Internal error: Invalid CscMatrix format during dense-> CSC conversion")
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Converts a [`CsrMatrix`] to a [`CscMatrix`].
|
||||||
|
pub fn convert_csr_csc<T>(csr: &CsrMatrix<T>) -> CscMatrix<T>
|
||||||
|
where
|
||||||
|
T: Scalar,
|
||||||
|
{
|
||||||
|
let (offsets, indices, values) = cs::transpose_cs(
|
||||||
|
csr.nrows(),
|
||||||
|
csr.ncols(),
|
||||||
|
csr.row_offsets(),
|
||||||
|
csr.col_indices(),
|
||||||
|
csr.values(),
|
||||||
|
);
|
||||||
|
|
||||||
|
// TODO: Avoid data validity check?
|
||||||
|
CscMatrix::try_from_csc_data(csr.nrows(), csr.ncols(), offsets, indices, values)
|
||||||
|
.expect("Internal error: Invalid CSC data during CSR->CSC conversion")
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Converts a [`CscMatrix`] to a [`CsrMatrix`].
|
||||||
|
pub fn convert_csc_csr<T>(csc: &CscMatrix<T>) -> CsrMatrix<T>
|
||||||
|
where
|
||||||
|
T: Scalar,
|
||||||
|
{
|
||||||
|
let (offsets, indices, values) = cs::transpose_cs(
|
||||||
|
csc.ncols(),
|
||||||
|
csc.nrows(),
|
||||||
|
csc.col_offsets(),
|
||||||
|
csc.row_indices(),
|
||||||
|
csc.values(),
|
||||||
|
);
|
||||||
|
|
||||||
|
// TODO: Avoid data validity check?
|
||||||
|
CsrMatrix::try_from_csr_data(csc.nrows(), csc.ncols(), offsets, indices, values)
|
||||||
|
.expect("Internal error: Invalid CSR data during CSC->CSR conversion")
|
||||||
|
}
|
||||||
|
|
||||||
|
fn convert_coo_cs<T>(
|
||||||
|
major_dim: usize,
|
||||||
|
major_indices: &[usize],
|
||||||
|
minor_indices: &[usize],
|
||||||
|
values: &[T],
|
||||||
|
) -> (Vec<usize>, Vec<usize>, Vec<T>)
|
||||||
|
where
|
||||||
|
T: Scalar + Zero,
|
||||||
|
{
|
||||||
|
assert_eq!(major_indices.len(), minor_indices.len());
|
||||||
|
assert_eq!(minor_indices.len(), values.len());
|
||||||
|
let nnz = major_indices.len();
|
||||||
|
|
||||||
|
let (unsorted_major_offsets, unsorted_minor_idx, unsorted_vals) = {
|
||||||
|
let mut offsets = vec![0usize; major_dim + 1];
|
||||||
|
let mut minor_idx = vec![0usize; nnz];
|
||||||
|
let mut vals = vec![T::zero(); nnz];
|
||||||
|
coo_to_unsorted_cs(
|
||||||
|
&mut offsets,
|
||||||
|
&mut minor_idx,
|
||||||
|
&mut vals,
|
||||||
|
major_dim,
|
||||||
|
major_indices,
|
||||||
|
minor_indices,
|
||||||
|
values,
|
||||||
|
);
|
||||||
|
(offsets, minor_idx, vals)
|
||||||
|
};
|
||||||
|
|
||||||
|
// TODO: If input is sorted and/or without duplicates, we can avoid additional allocations
|
||||||
|
// and work. Might want to take advantage of this.
|
||||||
|
|
||||||
|
// At this point, assembly is essentially complete. However, we must ensure
|
||||||
|
// that minor indices are sorted within each lane and without duplicates.
|
||||||
|
let mut sorted_major_offsets = Vec::new();
|
||||||
|
let mut sorted_minor_idx = Vec::new();
|
||||||
|
let mut sorted_vals = Vec::new();
|
||||||
|
|
||||||
|
sorted_major_offsets.push(0);
|
||||||
|
|
||||||
|
// We need some temporary storage when working with each lane. Since lanes often have a
|
||||||
|
// very small number of non-zero entries, we try to amortize allocations across
|
||||||
|
// lanes by reusing workspace vectors
|
||||||
|
let mut idx_workspace = Vec::new();
|
||||||
|
let mut perm_workspace = Vec::new();
|
||||||
|
let mut values_workspace = Vec::new();
|
||||||
|
|
||||||
|
for lane in 0..major_dim {
|
||||||
|
let begin = unsorted_major_offsets[lane];
|
||||||
|
let end = unsorted_major_offsets[lane + 1];
|
||||||
|
let count = end - begin;
|
||||||
|
let range = begin..end;
|
||||||
|
|
||||||
|
// Ensure that workspaces can hold enough data
|
||||||
|
perm_workspace.resize(count, 0);
|
||||||
|
idx_workspace.resize(count, 0);
|
||||||
|
values_workspace.resize(count, T::zero());
|
||||||
|
sort_lane(
|
||||||
|
&mut idx_workspace[..count],
|
||||||
|
&mut values_workspace[..count],
|
||||||
|
&unsorted_minor_idx[range.clone()],
|
||||||
|
&unsorted_vals[range.clone()],
|
||||||
|
&mut perm_workspace[..count],
|
||||||
|
);
|
||||||
|
|
||||||
|
let sorted_ja_current_len = sorted_minor_idx.len();
|
||||||
|
|
||||||
|
combine_duplicates(
|
||||||
|
|idx| sorted_minor_idx.push(idx),
|
||||||
|
|val| sorted_vals.push(val),
|
||||||
|
&idx_workspace[..count],
|
||||||
|
&values_workspace[..count],
|
||||||
|
&Add::add,
|
||||||
|
);
|
||||||
|
|
||||||
|
let new_col_count = sorted_minor_idx.len() - sorted_ja_current_len;
|
||||||
|
sorted_major_offsets.push(sorted_major_offsets.last().unwrap() + new_col_count);
|
||||||
|
}
|
||||||
|
|
||||||
|
(sorted_major_offsets, sorted_minor_idx, sorted_vals)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Converts matrix data given in triplet format to unsorted CSR/CSC, retaining any duplicated
|
||||||
|
/// indices.
|
||||||
|
///
|
||||||
|
/// Here `major/minor` is `row/col` for CSR and `col/row` for CSC.
|
||||||
|
fn coo_to_unsorted_cs<T: Clone>(
|
||||||
|
major_offsets: &mut [usize],
|
||||||
|
cs_minor_idx: &mut [usize],
|
||||||
|
cs_values: &mut [T],
|
||||||
|
major_dim: usize,
|
||||||
|
major_indices: &[usize],
|
||||||
|
minor_indices: &[usize],
|
||||||
|
coo_values: &[T],
|
||||||
|
) {
|
||||||
|
assert_eq!(major_offsets.len(), major_dim + 1);
|
||||||
|
assert_eq!(cs_minor_idx.len(), cs_values.len());
|
||||||
|
assert_eq!(cs_values.len(), major_indices.len());
|
||||||
|
assert_eq!(major_indices.len(), minor_indices.len());
|
||||||
|
assert_eq!(minor_indices.len(), coo_values.len());
|
||||||
|
|
||||||
|
// Count the number of occurrences of each row
|
||||||
|
for major_idx in major_indices {
|
||||||
|
major_offsets[*major_idx] += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
cs::convert_counts_to_offsets(major_offsets);
|
||||||
|
|
||||||
|
{
|
||||||
|
// TODO: Instead of allocating a whole new vector storing the current counts,
|
||||||
|
// I think it's possible to be a bit more clever by storing each count
|
||||||
|
// in the last of the column indices for each row
|
||||||
|
let mut current_counts = vec![0usize; major_dim + 1];
|
||||||
|
let triplet_iter = major_indices.iter().zip(minor_indices).zip(coo_values);
|
||||||
|
for ((i, j), value) in triplet_iter {
|
||||||
|
let current_offset = major_offsets[*i] + current_counts[*i];
|
||||||
|
cs_minor_idx[current_offset] = *j;
|
||||||
|
cs_values[current_offset] = value.clone();
|
||||||
|
current_counts[*i] += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Sort the indices of the given lane.
|
||||||
|
///
|
||||||
|
/// The indices and values in `minor_idx` and `values` are sorted according to the
|
||||||
|
/// minor indices and stored in `minor_idx_result` and `values_result` respectively.
|
||||||
|
///
|
||||||
|
/// All input slices are expected to be of the same length. The contents of mutable slices
|
||||||
|
/// can be arbitrary, as they are anyway overwritten.
|
||||||
|
fn sort_lane<T: Clone>(
|
||||||
|
minor_idx_result: &mut [usize],
|
||||||
|
values_result: &mut [T],
|
||||||
|
minor_idx: &[usize],
|
||||||
|
values: &[T],
|
||||||
|
workspace: &mut [usize],
|
||||||
|
) {
|
||||||
|
assert_eq!(minor_idx_result.len(), values_result.len());
|
||||||
|
assert_eq!(values_result.len(), minor_idx.len());
|
||||||
|
assert_eq!(minor_idx.len(), values.len());
|
||||||
|
assert_eq!(values.len(), workspace.len());
|
||||||
|
|
||||||
|
let permutation = workspace;
|
||||||
|
// Set permutation to identity
|
||||||
|
for (i, p) in permutation.iter_mut().enumerate() {
|
||||||
|
*p = i;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute permutation needed to bring minor indices into sorted order
|
||||||
|
// Note: Using sort_unstable here avoids internal allocations, which is crucial since
|
||||||
|
// each lane might have a small number of elements
|
||||||
|
permutation.sort_unstable_by_key(|idx| minor_idx[*idx]);
|
||||||
|
|
||||||
|
apply_permutation(minor_idx_result, minor_idx, permutation);
|
||||||
|
apply_permutation(values_result, values, permutation);
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Move this into `utils` or something?
|
||||||
|
fn apply_permutation<T: Clone>(out_slice: &mut [T], in_slice: &[T], permutation: &[usize]) {
|
||||||
|
assert_eq!(out_slice.len(), in_slice.len());
|
||||||
|
assert_eq!(out_slice.len(), permutation.len());
|
||||||
|
for (out_element, old_pos) in out_slice.iter_mut().zip(permutation) {
|
||||||
|
*out_element = in_slice[*old_pos].clone();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Given *sorted* indices and corresponding scalar values, combines duplicates with the given
|
||||||
|
/// associative combiner and calls the provided produce methods with combined indices and values.
|
||||||
|
fn combine_duplicates<T: Clone>(
|
||||||
|
mut produce_idx: impl FnMut(usize),
|
||||||
|
mut produce_value: impl FnMut(T),
|
||||||
|
idx_array: &[usize],
|
||||||
|
values: &[T],
|
||||||
|
combiner: impl Fn(T, T) -> T,
|
||||||
|
) {
|
||||||
|
assert_eq!(idx_array.len(), values.len());
|
||||||
|
|
||||||
|
let mut i = 0;
|
||||||
|
while i < idx_array.len() {
|
||||||
|
let idx = idx_array[i];
|
||||||
|
let mut combined_value = values[i].clone();
|
||||||
|
let mut j = i + 1;
|
||||||
|
while j < idx_array.len() && idx_array[j] == idx {
|
||||||
|
let j_val = values[j].clone();
|
||||||
|
combined_value = combiner(combined_value, j_val);
|
||||||
|
j += 1;
|
||||||
|
}
|
||||||
|
produce_idx(idx);
|
||||||
|
produce_value(combined_value);
|
||||||
|
i = j;
|
||||||
|
}
|
||||||
|
}
|
208
nalgebra-sparse/src/coo.rs
Normal file
208
nalgebra-sparse/src/coo.rs
Normal file
@ -0,0 +1,208 @@
|
|||||||
|
//! An implementation of the COO sparse matrix format.
|
||||||
|
|
||||||
|
use crate::SparseFormatError;
|
||||||
|
|
||||||
|
/// A COO representation of a sparse matrix.
|
||||||
|
///
|
||||||
|
/// A COO matrix stores entries in coordinate-form, that is triplets `(i, j, v)`, where `i` and `j`
|
||||||
|
/// correspond to row and column indices of the entry, and `v` to the value of the entry.
|
||||||
|
/// The format is of limited use for standard matrix operations. Its main purpose is to facilitate
|
||||||
|
/// easy construction of other, more efficient matrix formats (such as CSR/COO), and the
|
||||||
|
/// conversion between different formats.
|
||||||
|
///
|
||||||
|
/// # Format
|
||||||
|
///
|
||||||
|
/// For given dimensions `nrows` and `ncols`, the matrix is represented by three same-length
|
||||||
|
/// arrays `row_indices`, `col_indices` and `values` that constitute the coordinate triplets
|
||||||
|
/// of the matrix. The indices must be in bounds, but *duplicate entries are explicitly allowed*.
|
||||||
|
/// Upon conversion to other formats, the duplicate entries may be summed together. See the
|
||||||
|
/// documentation for the respective conversion functions.
|
||||||
|
///
|
||||||
|
/// # Examples
|
||||||
|
///
|
||||||
|
/// ```rust
|
||||||
|
/// use nalgebra_sparse::{coo::CooMatrix, csr::CsrMatrix, csc::CscMatrix};
|
||||||
|
///
|
||||||
|
/// // Initialize a matrix with all zeros (no explicitly stored entries).
|
||||||
|
/// let mut coo = CooMatrix::new(4, 4);
|
||||||
|
/// // Or initialize it with a set of triplets
|
||||||
|
/// coo = CooMatrix::try_from_triplets(4, 4, vec![1, 2], vec![0, 1], vec![3.0, 4.0]).unwrap();
|
||||||
|
///
|
||||||
|
/// // Push a few triplets
|
||||||
|
/// coo.push(2, 0, 1.0);
|
||||||
|
/// coo.push(0, 1, 2.0);
|
||||||
|
///
|
||||||
|
/// // Convert to other matrix formats
|
||||||
|
/// let csr = CsrMatrix::from(&coo);
|
||||||
|
/// let csc = CscMatrix::from(&coo);
|
||||||
|
/// ```
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub struct CooMatrix<T> {
|
||||||
|
nrows: usize,
|
||||||
|
ncols: usize,
|
||||||
|
row_indices: Vec<usize>,
|
||||||
|
col_indices: Vec<usize>,
|
||||||
|
values: Vec<T>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T> CooMatrix<T> {
|
||||||
|
/// Construct a zero COO matrix of the given dimensions.
|
||||||
|
///
|
||||||
|
/// Specifically, the collection of triplets - corresponding to explicitly stored entries -
|
||||||
|
/// is empty, so that the matrix (implicitly) represented by the COO matrix consists of all
|
||||||
|
/// zero entries.
|
||||||
|
pub fn new(nrows: usize, ncols: usize) -> Self {
|
||||||
|
Self {
|
||||||
|
nrows,
|
||||||
|
ncols,
|
||||||
|
row_indices: Vec::new(),
|
||||||
|
col_indices: Vec::new(),
|
||||||
|
values: Vec::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Construct a zero COO matrix of the given dimensions.
|
||||||
|
///
|
||||||
|
/// Specifically, the collection of triplets - corresponding to explicitly stored entries -
|
||||||
|
/// is empty, so that the matrix (implicitly) represented by the COO matrix consists of all
|
||||||
|
/// zero entries.
|
||||||
|
pub fn zeros(nrows: usize, ncols: usize) -> Self {
|
||||||
|
Self::new(nrows, ncols)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Try to construct a COO matrix from the given dimensions and a collection of
|
||||||
|
/// (i, j, v) triplets.
|
||||||
|
///
|
||||||
|
/// Returns an error if either row or column indices contain indices out of bounds,
|
||||||
|
/// or if the data arrays do not all have the same length. Note that the COO format
|
||||||
|
/// inherently supports duplicate entries.
|
||||||
|
pub fn try_from_triplets(
|
||||||
|
nrows: usize,
|
||||||
|
ncols: usize,
|
||||||
|
row_indices: Vec<usize>,
|
||||||
|
col_indices: Vec<usize>,
|
||||||
|
values: Vec<T>,
|
||||||
|
) -> Result<Self, SparseFormatError> {
|
||||||
|
use crate::SparseFormatErrorKind::*;
|
||||||
|
if row_indices.len() != col_indices.len() {
|
||||||
|
return Err(SparseFormatError::from_kind_and_msg(
|
||||||
|
InvalidStructure,
|
||||||
|
"Number of row and col indices must be the same.",
|
||||||
|
));
|
||||||
|
} else if col_indices.len() != values.len() {
|
||||||
|
return Err(SparseFormatError::from_kind_and_msg(
|
||||||
|
InvalidStructure,
|
||||||
|
"Number of col indices and values must be the same.",
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
let row_indices_in_bounds = row_indices.iter().all(|i| *i < nrows);
|
||||||
|
let col_indices_in_bounds = col_indices.iter().all(|j| *j < ncols);
|
||||||
|
|
||||||
|
if !row_indices_in_bounds {
|
||||||
|
Err(SparseFormatError::from_kind_and_msg(
|
||||||
|
IndexOutOfBounds,
|
||||||
|
"Row index out of bounds.",
|
||||||
|
))
|
||||||
|
} else if !col_indices_in_bounds {
|
||||||
|
Err(SparseFormatError::from_kind_and_msg(
|
||||||
|
IndexOutOfBounds,
|
||||||
|
"Col index out of bounds.",
|
||||||
|
))
|
||||||
|
} else {
|
||||||
|
Ok(Self {
|
||||||
|
nrows,
|
||||||
|
ncols,
|
||||||
|
row_indices,
|
||||||
|
col_indices,
|
||||||
|
values,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// An iterator over triplets (i, j, v).
|
||||||
|
// TODO: Consider giving the iterator a concrete type instead of impl trait...?
|
||||||
|
pub fn triplet_iter(&self) -> impl Iterator<Item = (usize, usize, &T)> {
|
||||||
|
self.row_indices
|
||||||
|
.iter()
|
||||||
|
.zip(&self.col_indices)
|
||||||
|
.zip(&self.values)
|
||||||
|
.map(|((i, j), v)| (*i, *j, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Push a single triplet to the matrix.
|
||||||
|
///
|
||||||
|
/// This adds the value `v` to the `i`th row and `j`th column in the matrix.
|
||||||
|
///
|
||||||
|
/// Panics
|
||||||
|
/// ------
|
||||||
|
///
|
||||||
|
/// Panics if `i` or `j` is out of bounds.
|
||||||
|
#[inline]
|
||||||
|
pub fn push(&mut self, i: usize, j: usize, v: T) {
|
||||||
|
assert!(i < self.nrows);
|
||||||
|
assert!(j < self.ncols);
|
||||||
|
self.row_indices.push(i);
|
||||||
|
self.col_indices.push(j);
|
||||||
|
self.values.push(v);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The number of rows in the matrix.
|
||||||
|
#[inline]
|
||||||
|
pub fn nrows(&self) -> usize {
|
||||||
|
self.nrows
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The number of columns in the matrix.
|
||||||
|
#[inline]
|
||||||
|
pub fn ncols(&self) -> usize {
|
||||||
|
self.ncols
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The number of explicitly stored entries in the matrix.
|
||||||
|
///
|
||||||
|
/// This number *includes* duplicate entries. For example, if the `CooMatrix` contains duplicate
|
||||||
|
/// entries, then it may have a different number of non-zeros as reported by `nnz()` compared
|
||||||
|
/// to its CSR representation.
|
||||||
|
#[inline]
|
||||||
|
pub fn nnz(&self) -> usize {
|
||||||
|
self.values.len()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The row indices of the explicitly stored entries.
|
||||||
|
pub fn row_indices(&self) -> &[usize] {
|
||||||
|
&self.row_indices
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The column indices of the explicitly stored entries.
|
||||||
|
pub fn col_indices(&self) -> &[usize] {
|
||||||
|
&self.col_indices
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The values of the explicitly stored entries.
|
||||||
|
pub fn values(&self) -> &[T] {
|
||||||
|
&self.values
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Disassembles the matrix into individual triplet arrays.
|
||||||
|
///
|
||||||
|
/// Examples
|
||||||
|
/// --------
|
||||||
|
///
|
||||||
|
/// ```
|
||||||
|
/// # use nalgebra_sparse::coo::CooMatrix;
|
||||||
|
/// let row_indices = vec![0, 1];
|
||||||
|
/// let col_indices = vec![1, 2];
|
||||||
|
/// let values = vec![1.0, 2.0];
|
||||||
|
/// let coo = CooMatrix::try_from_triplets(2, 3, row_indices, col_indices, values)
|
||||||
|
/// .unwrap();
|
||||||
|
///
|
||||||
|
/// let (row_idx, col_idx, val) = coo.disassemble();
|
||||||
|
/// assert_eq!(row_idx, vec![0, 1]);
|
||||||
|
/// assert_eq!(col_idx, vec![1, 2]);
|
||||||
|
/// assert_eq!(val, vec![1.0, 2.0]);
|
||||||
|
/// ```
|
||||||
|
pub fn disassemble(self) -> (Vec<usize>, Vec<usize>, Vec<T>) {
|
||||||
|
(self.row_indices, self.col_indices, self.values)
|
||||||
|
}
|
||||||
|
}
|
530
nalgebra-sparse/src/cs.rs
Normal file
530
nalgebra-sparse/src/cs.rs
Normal file
@ -0,0 +1,530 @@
|
|||||||
|
use std::mem::replace;
|
||||||
|
use std::ops::Range;
|
||||||
|
|
||||||
|
use num_traits::One;
|
||||||
|
|
||||||
|
use nalgebra::Scalar;
|
||||||
|
|
||||||
|
use crate::pattern::SparsityPattern;
|
||||||
|
use crate::{SparseEntry, SparseEntryMut};
|
||||||
|
|
||||||
|
/// An abstract compressed matrix.
|
||||||
|
///
|
||||||
|
/// For the time being, this is only used internally to share implementation between
|
||||||
|
/// CSR and CSC matrices.
|
||||||
|
///
|
||||||
|
/// A CSR matrix is obtained by associating rows with the major dimension, while a CSC matrix
|
||||||
|
/// is obtained by associating columns with the major dimension.
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub struct CsMatrix<T> {
|
||||||
|
sparsity_pattern: SparsityPattern,
|
||||||
|
values: Vec<T>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T> CsMatrix<T> {
|
||||||
|
/// Create a zero matrix with no explicitly stored entries.
|
||||||
|
#[inline]
|
||||||
|
pub fn new(major_dim: usize, minor_dim: usize) -> Self {
|
||||||
|
Self {
|
||||||
|
sparsity_pattern: SparsityPattern::zeros(major_dim, minor_dim),
|
||||||
|
values: vec![],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn pattern(&self) -> &SparsityPattern {
|
||||||
|
&self.sparsity_pattern
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn values(&self) -> &[T] {
|
||||||
|
&self.values
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn values_mut(&mut self) -> &mut [T] {
|
||||||
|
&mut self.values
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the raw data represented as a tuple `(major_offsets, minor_indices, values)`.
|
||||||
|
#[inline]
|
||||||
|
pub fn cs_data(&self) -> (&[usize], &[usize], &[T]) {
|
||||||
|
let pattern = self.pattern();
|
||||||
|
(
|
||||||
|
pattern.major_offsets(),
|
||||||
|
pattern.minor_indices(),
|
||||||
|
&self.values,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the raw data represented as a tuple `(major_offsets, minor_indices, values)`.
|
||||||
|
#[inline]
|
||||||
|
pub fn cs_data_mut(&mut self) -> (&[usize], &[usize], &mut [T]) {
|
||||||
|
let pattern = &mut self.sparsity_pattern;
|
||||||
|
(
|
||||||
|
pattern.major_offsets(),
|
||||||
|
pattern.minor_indices(),
|
||||||
|
&mut self.values,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn pattern_and_values_mut(&mut self) -> (&SparsityPattern, &mut [T]) {
|
||||||
|
(&self.sparsity_pattern, &mut self.values)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn from_pattern_and_values(pattern: SparsityPattern, values: Vec<T>) -> Self {
|
||||||
|
assert_eq!(
|
||||||
|
pattern.nnz(),
|
||||||
|
values.len(),
|
||||||
|
"Internal error: consumers should verify shape compatibility."
|
||||||
|
);
|
||||||
|
Self {
|
||||||
|
sparsity_pattern: pattern,
|
||||||
|
values,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Internal method for simplifying access to a lane's data
|
||||||
|
#[inline]
|
||||||
|
pub fn get_index_range(&self, row_index: usize) -> Option<Range<usize>> {
|
||||||
|
let row_begin = *self.sparsity_pattern.major_offsets().get(row_index)?;
|
||||||
|
let row_end = *self.sparsity_pattern.major_offsets().get(row_index + 1)?;
|
||||||
|
Some(row_begin..row_end)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn take_pattern_and_values(self) -> (SparsityPattern, Vec<T>) {
|
||||||
|
(self.sparsity_pattern, self.values)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn disassemble(self) -> (Vec<usize>, Vec<usize>, Vec<T>) {
|
||||||
|
let (offsets, indices) = self.sparsity_pattern.disassemble();
|
||||||
|
(offsets, indices, self.values)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn into_pattern_and_values(self) -> (SparsityPattern, Vec<T>) {
|
||||||
|
(self.sparsity_pattern, self.values)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns an entry for the given major/minor indices, or `None` if the indices are out
|
||||||
|
/// of bounds.
|
||||||
|
pub fn get_entry(&self, major_index: usize, minor_index: usize) -> Option<SparseEntry<T>> {
|
||||||
|
let row_range = self.get_index_range(major_index)?;
|
||||||
|
let (_, minor_indices, values) = self.cs_data();
|
||||||
|
let minor_indices = &minor_indices[row_range.clone()];
|
||||||
|
let values = &values[row_range];
|
||||||
|
get_entry_from_slices(
|
||||||
|
self.pattern().minor_dim(),
|
||||||
|
minor_indices,
|
||||||
|
values,
|
||||||
|
minor_index,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns a mutable entry for the given major/minor indices, or `None` if the indices are out
|
||||||
|
/// of bounds.
|
||||||
|
pub fn get_entry_mut(
|
||||||
|
&mut self,
|
||||||
|
major_index: usize,
|
||||||
|
minor_index: usize,
|
||||||
|
) -> Option<SparseEntryMut<T>> {
|
||||||
|
let row_range = self.get_index_range(major_index)?;
|
||||||
|
let minor_dim = self.pattern().minor_dim();
|
||||||
|
let (_, minor_indices, values) = self.cs_data_mut();
|
||||||
|
let minor_indices = &minor_indices[row_range.clone()];
|
||||||
|
let values = &mut values[row_range];
|
||||||
|
get_mut_entry_from_slices(minor_dim, minor_indices, values, minor_index)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_lane(&self, index: usize) -> Option<CsLane<T>> {
|
||||||
|
let range = self.get_index_range(index)?;
|
||||||
|
let (_, minor_indices, values) = self.cs_data();
|
||||||
|
Some(CsLane {
|
||||||
|
minor_indices: &minor_indices[range.clone()],
|
||||||
|
values: &values[range],
|
||||||
|
minor_dim: self.pattern().minor_dim(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn get_lane_mut(&mut self, index: usize) -> Option<CsLaneMut<T>> {
|
||||||
|
let range = self.get_index_range(index)?;
|
||||||
|
let minor_dim = self.pattern().minor_dim();
|
||||||
|
let (_, minor_indices, values) = self.cs_data_mut();
|
||||||
|
Some(CsLaneMut {
|
||||||
|
minor_dim,
|
||||||
|
minor_indices: &minor_indices[range.clone()],
|
||||||
|
values: &mut values[range],
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn lane_iter(&self) -> CsLaneIter<T> {
|
||||||
|
CsLaneIter::new(self.pattern(), self.values())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn lane_iter_mut(&mut self) -> CsLaneIterMut<T> {
|
||||||
|
CsLaneIterMut::new(&self.sparsity_pattern, &mut self.values)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn filter<P>(&self, predicate: P) -> Self
|
||||||
|
where
|
||||||
|
T: Clone,
|
||||||
|
P: Fn(usize, usize, &T) -> bool,
|
||||||
|
{
|
||||||
|
let (major_dim, minor_dim) = (self.pattern().major_dim(), self.pattern().minor_dim());
|
||||||
|
let mut new_offsets = Vec::with_capacity(self.pattern().major_dim() + 1);
|
||||||
|
let mut new_indices = Vec::new();
|
||||||
|
let mut new_values = Vec::new();
|
||||||
|
|
||||||
|
new_offsets.push(0);
|
||||||
|
for (i, lane) in self.lane_iter().enumerate() {
|
||||||
|
for (&j, value) in lane.minor_indices().iter().zip(lane.values) {
|
||||||
|
if predicate(i, j, value) {
|
||||||
|
new_indices.push(j);
|
||||||
|
new_values.push(value.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
new_offsets.push(new_indices.len());
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Avoid checks here
|
||||||
|
let new_pattern = SparsityPattern::try_from_offsets_and_indices(
|
||||||
|
major_dim,
|
||||||
|
minor_dim,
|
||||||
|
new_offsets,
|
||||||
|
new_indices,
|
||||||
|
)
|
||||||
|
.expect("Internal error: Sparsity pattern must always be valid.");
|
||||||
|
|
||||||
|
Self::from_pattern_and_values(new_pattern, new_values)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the diagonal of the matrix as a sparse matrix.
|
||||||
|
pub fn diagonal_as_matrix(&self) -> Self
|
||||||
|
where
|
||||||
|
T: Clone,
|
||||||
|
{
|
||||||
|
// TODO: This might be faster with a binary search for each diagonal entry
|
||||||
|
self.filter(|i, j, _| i == j)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: Scalar + One> CsMatrix<T> {
|
||||||
|
#[inline]
|
||||||
|
pub fn identity(n: usize) -> Self {
|
||||||
|
let offsets: Vec<_> = (0..=n).collect();
|
||||||
|
let indices: Vec<_> = (0..n).collect();
|
||||||
|
let values = vec![T::one(); n];
|
||||||
|
|
||||||
|
// TODO: We should skip checks here
|
||||||
|
let pattern =
|
||||||
|
SparsityPattern::try_from_offsets_and_indices(n, n, offsets, indices).unwrap();
|
||||||
|
Self::from_pattern_and_values(pattern, values)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_entry_from_slices<'a, T>(
|
||||||
|
minor_dim: usize,
|
||||||
|
minor_indices: &'a [usize],
|
||||||
|
values: &'a [T],
|
||||||
|
global_minor_index: usize,
|
||||||
|
) -> Option<SparseEntry<'a, T>> {
|
||||||
|
let local_index = minor_indices.binary_search(&global_minor_index);
|
||||||
|
if let Ok(local_index) = local_index {
|
||||||
|
Some(SparseEntry::NonZero(&values[local_index]))
|
||||||
|
} else if global_minor_index < minor_dim {
|
||||||
|
Some(SparseEntry::Zero)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_mut_entry_from_slices<'a, T>(
|
||||||
|
minor_dim: usize,
|
||||||
|
minor_indices: &'a [usize],
|
||||||
|
values: &'a mut [T],
|
||||||
|
global_minor_indices: usize,
|
||||||
|
) -> Option<SparseEntryMut<'a, T>> {
|
||||||
|
let local_index = minor_indices.binary_search(&global_minor_indices);
|
||||||
|
if let Ok(local_index) = local_index {
|
||||||
|
Some(SparseEntryMut::NonZero(&mut values[local_index]))
|
||||||
|
} else if global_minor_indices < minor_dim {
|
||||||
|
Some(SparseEntryMut::Zero)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub struct CsLane<'a, T> {
|
||||||
|
minor_dim: usize,
|
||||||
|
minor_indices: &'a [usize],
|
||||||
|
values: &'a [T],
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, PartialEq, Eq)]
|
||||||
|
pub struct CsLaneMut<'a, T> {
|
||||||
|
minor_dim: usize,
|
||||||
|
minor_indices: &'a [usize],
|
||||||
|
values: &'a mut [T],
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct CsLaneIter<'a, T> {
|
||||||
|
// The index of the lane that will be returned on the next iteration
|
||||||
|
current_lane_idx: usize,
|
||||||
|
pattern: &'a SparsityPattern,
|
||||||
|
remaining_values: &'a [T],
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, T> CsLaneIter<'a, T> {
|
||||||
|
pub fn new(pattern: &'a SparsityPattern, values: &'a [T]) -> Self {
|
||||||
|
Self {
|
||||||
|
current_lane_idx: 0,
|
||||||
|
pattern,
|
||||||
|
remaining_values: values,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, T> Iterator for CsLaneIter<'a, T>
|
||||||
|
where
|
||||||
|
T: 'a,
|
||||||
|
{
|
||||||
|
type Item = CsLane<'a, T>;
|
||||||
|
|
||||||
|
fn next(&mut self) -> Option<Self::Item> {
|
||||||
|
let lane = self.pattern.get_lane(self.current_lane_idx);
|
||||||
|
let minor_dim = self.pattern.minor_dim();
|
||||||
|
|
||||||
|
if let Some(minor_indices) = lane {
|
||||||
|
let count = minor_indices.len();
|
||||||
|
let values_in_lane = &self.remaining_values[..count];
|
||||||
|
self.remaining_values = &self.remaining_values[count..];
|
||||||
|
self.current_lane_idx += 1;
|
||||||
|
|
||||||
|
Some(CsLane {
|
||||||
|
minor_dim,
|
||||||
|
minor_indices,
|
||||||
|
values: values_in_lane,
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct CsLaneIterMut<'a, T> {
|
||||||
|
// The index of the lane that will be returned on the next iteration
|
||||||
|
current_lane_idx: usize,
|
||||||
|
pattern: &'a SparsityPattern,
|
||||||
|
remaining_values: &'a mut [T],
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, T> CsLaneIterMut<'a, T> {
|
||||||
|
pub fn new(pattern: &'a SparsityPattern, values: &'a mut [T]) -> Self {
|
||||||
|
Self {
|
||||||
|
current_lane_idx: 0,
|
||||||
|
pattern,
|
||||||
|
remaining_values: values,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, T> Iterator for CsLaneIterMut<'a, T>
|
||||||
|
where
|
||||||
|
T: 'a,
|
||||||
|
{
|
||||||
|
type Item = CsLaneMut<'a, T>;
|
||||||
|
|
||||||
|
fn next(&mut self) -> Option<Self::Item> {
|
||||||
|
let lane = self.pattern.get_lane(self.current_lane_idx);
|
||||||
|
let minor_dim = self.pattern.minor_dim();
|
||||||
|
|
||||||
|
if let Some(minor_indices) = lane {
|
||||||
|
let count = minor_indices.len();
|
||||||
|
|
||||||
|
let remaining = replace(&mut self.remaining_values, &mut []);
|
||||||
|
let (values_in_lane, remaining) = remaining.split_at_mut(count);
|
||||||
|
self.remaining_values = remaining;
|
||||||
|
self.current_lane_idx += 1;
|
||||||
|
|
||||||
|
Some(CsLaneMut {
|
||||||
|
minor_dim,
|
||||||
|
minor_indices,
|
||||||
|
values: values_in_lane,
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Implement the methods common to both CsLane and CsLaneMut. See the documentation for the
|
||||||
|
/// methods delegated here by CsrMatrix and CscMatrix members for more information.
|
||||||
|
macro_rules! impl_cs_lane_common_methods {
|
||||||
|
($name:ty) => {
|
||||||
|
impl<'a, T> $name {
|
||||||
|
#[inline]
|
||||||
|
pub fn minor_dim(&self) -> usize {
|
||||||
|
self.minor_dim
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn nnz(&self) -> usize {
|
||||||
|
self.minor_indices.len()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn minor_indices(&self) -> &[usize] {
|
||||||
|
self.minor_indices
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn values(&self) -> &[T] {
|
||||||
|
self.values
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn get_entry(&self, global_col_index: usize) -> Option<SparseEntry<T>> {
|
||||||
|
get_entry_from_slices(
|
||||||
|
self.minor_dim,
|
||||||
|
self.minor_indices,
|
||||||
|
self.values,
|
||||||
|
global_col_index,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
impl_cs_lane_common_methods!(CsLane<'a, T>);
|
||||||
|
impl_cs_lane_common_methods!(CsLaneMut<'a, T>);
|
||||||
|
|
||||||
|
impl<'a, T> CsLaneMut<'a, T> {
|
||||||
|
pub fn values_mut(&mut self) -> &mut [T] {
|
||||||
|
self.values
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn indices_and_values_mut(&mut self) -> (&[usize], &mut [T]) {
|
||||||
|
(self.minor_indices, self.values)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_entry_mut(&mut self, global_minor_index: usize) -> Option<SparseEntryMut<T>> {
|
||||||
|
get_mut_entry_from_slices(
|
||||||
|
self.minor_dim,
|
||||||
|
self.minor_indices,
|
||||||
|
self.values,
|
||||||
|
global_minor_index,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Helper struct for working with uninitialized data in vectors.
|
||||||
|
/// TODO: This doesn't belong here.
|
||||||
|
struct UninitVec<T> {
|
||||||
|
vec: Vec<T>,
|
||||||
|
len: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T> UninitVec<T> {
|
||||||
|
pub fn from_len(len: usize) -> Self {
|
||||||
|
Self {
|
||||||
|
vec: Vec::with_capacity(len),
|
||||||
|
// We need to store len separately, because for zero-sized types,
|
||||||
|
// Vec::with_capacity(len) does not give vec.capacity() == len
|
||||||
|
len,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Sets the element associated with the given index to the provided value.
|
||||||
|
///
|
||||||
|
/// Must be called exactly once per index, otherwise results in undefined behavior.
|
||||||
|
pub unsafe fn set(&mut self, index: usize, value: T) {
|
||||||
|
self.vec.as_mut_ptr().add(index).write(value)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Marks the vector data as initialized by returning a full vector.
|
||||||
|
///
|
||||||
|
/// It is undefined behavior to call this function unless *all* elements have been written to
|
||||||
|
/// exactly once.
|
||||||
|
pub unsafe fn assume_init(mut self) -> Vec<T> {
|
||||||
|
self.vec.set_len(self.len);
|
||||||
|
self.vec
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Transposes the compressed format.
|
||||||
|
///
|
||||||
|
/// This means that major and minor roles are switched. This is used for converting between CSR
|
||||||
|
/// and CSC formats.
|
||||||
|
pub fn transpose_cs<T>(
|
||||||
|
major_dim: usize,
|
||||||
|
minor_dim: usize,
|
||||||
|
source_major_offsets: &[usize],
|
||||||
|
source_minor_indices: &[usize],
|
||||||
|
values: &[T],
|
||||||
|
) -> (Vec<usize>, Vec<usize>, Vec<T>)
|
||||||
|
where
|
||||||
|
T: Scalar,
|
||||||
|
{
|
||||||
|
assert_eq!(source_major_offsets.len(), major_dim + 1);
|
||||||
|
assert_eq!(source_minor_indices.len(), values.len());
|
||||||
|
let nnz = values.len();
|
||||||
|
|
||||||
|
// Count the number of occurences of each minor index
|
||||||
|
let mut minor_counts = vec![0; minor_dim];
|
||||||
|
for minor_idx in source_minor_indices {
|
||||||
|
minor_counts[*minor_idx] += 1;
|
||||||
|
}
|
||||||
|
convert_counts_to_offsets(&mut minor_counts);
|
||||||
|
let mut target_offsets = minor_counts;
|
||||||
|
target_offsets.push(nnz);
|
||||||
|
let mut target_indices = vec![usize::MAX; nnz];
|
||||||
|
|
||||||
|
// We have to use uninitialized storage, because we don't have any kind of "default" value
|
||||||
|
// available for `T`. Unfortunately this necessitates some small amount of unsafe code
|
||||||
|
let mut target_values = UninitVec::from_len(nnz);
|
||||||
|
|
||||||
|
// Keep track of how many entries we have placed in each target major lane
|
||||||
|
let mut current_target_major_counts = vec![0; minor_dim];
|
||||||
|
|
||||||
|
for source_major_idx in 0..major_dim {
|
||||||
|
let source_lane_begin = source_major_offsets[source_major_idx];
|
||||||
|
let source_lane_end = source_major_offsets[source_major_idx + 1];
|
||||||
|
let source_lane_indices = &source_minor_indices[source_lane_begin..source_lane_end];
|
||||||
|
let source_lane_values = &values[source_lane_begin..source_lane_end];
|
||||||
|
|
||||||
|
for (&source_minor_idx, val) in source_lane_indices.iter().zip(source_lane_values) {
|
||||||
|
// Compute the offset in the target data for this particular source entry
|
||||||
|
let target_lane_count = &mut current_target_major_counts[source_minor_idx];
|
||||||
|
let entry_offset = target_offsets[source_minor_idx] + *target_lane_count;
|
||||||
|
target_indices[entry_offset] = source_major_idx;
|
||||||
|
unsafe {
|
||||||
|
target_values.set(entry_offset, val.inlined_clone());
|
||||||
|
}
|
||||||
|
*target_lane_count += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// At this point, we should have written to each element in target_values exactly once,
|
||||||
|
// so initialization should be sound
|
||||||
|
let target_values = unsafe { target_values.assume_init() };
|
||||||
|
(target_offsets, target_indices, target_values)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn convert_counts_to_offsets(counts: &mut [usize]) {
|
||||||
|
// Convert the counts to an offset
|
||||||
|
let mut offset = 0;
|
||||||
|
for i_offset in counts.iter_mut() {
|
||||||
|
let count = *i_offset;
|
||||||
|
*i_offset = offset;
|
||||||
|
offset += count;
|
||||||
|
}
|
||||||
|
}
|
704
nalgebra-sparse/src/csc.rs
Normal file
704
nalgebra-sparse/src/csc.rs
Normal file
@ -0,0 +1,704 @@
|
|||||||
|
//! An implementation of the CSC sparse matrix format.
|
||||||
|
//!
|
||||||
|
//! This is the module-level documentation. See [`CscMatrix`] for the main documentation of the
|
||||||
|
//! CSC implementation.
|
||||||
|
|
||||||
|
use crate::cs::{CsLane, CsLaneIter, CsLaneIterMut, CsLaneMut, CsMatrix};
|
||||||
|
use crate::csr::CsrMatrix;
|
||||||
|
use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter};
|
||||||
|
use crate::{SparseEntry, SparseEntryMut, SparseFormatError, SparseFormatErrorKind};
|
||||||
|
|
||||||
|
use nalgebra::Scalar;
|
||||||
|
use num_traits::One;
|
||||||
|
use std::slice::{Iter, IterMut};
|
||||||
|
|
||||||
|
/// A CSC representation of a sparse matrix.
|
||||||
|
///
|
||||||
|
/// The Compressed Sparse Column (CSC) format is well-suited as a general-purpose storage format
|
||||||
|
/// for many sparse matrix applications.
|
||||||
|
///
|
||||||
|
/// # Usage
|
||||||
|
///
|
||||||
|
/// ```rust
|
||||||
|
/// use nalgebra_sparse::csc::CscMatrix;
|
||||||
|
/// use nalgebra::{DMatrix, Matrix3x4};
|
||||||
|
/// use matrixcompare::assert_matrix_eq;
|
||||||
|
///
|
||||||
|
/// // The sparsity patterns of CSC matrices are immutable. This means that you cannot dynamically
|
||||||
|
/// // change the sparsity pattern of the matrix after it has been constructed. The easiest
|
||||||
|
/// // way to construct a CSC matrix is to first incrementally construct a COO matrix,
|
||||||
|
/// // and then convert it to CSC.
|
||||||
|
/// # use nalgebra_sparse::coo::CooMatrix;
|
||||||
|
/// # let coo = CooMatrix::<f64>::new(3, 3);
|
||||||
|
/// let csc = CscMatrix::from(&coo);
|
||||||
|
///
|
||||||
|
/// // Alternatively, a CSC matrix can be constructed directly from raw CSC data.
|
||||||
|
/// // Here, we construct a 3x4 matrix
|
||||||
|
/// let col_offsets = vec![0, 1, 3, 4, 5];
|
||||||
|
/// let row_indices = vec![0, 0, 2, 2, 0];
|
||||||
|
/// let values = vec![1.0, 2.0, 3.0, 4.0, 5.0];
|
||||||
|
///
|
||||||
|
/// // The dense representation of the CSC data, for comparison
|
||||||
|
/// let dense = Matrix3x4::new(1.0, 2.0, 0.0, 5.0,
|
||||||
|
/// 0.0, 0.0, 0.0, 0.0,
|
||||||
|
/// 0.0, 3.0, 4.0, 0.0);
|
||||||
|
///
|
||||||
|
/// // The constructor validates the raw CSC data and returns an error if it is invalid.
|
||||||
|
/// let csc = CscMatrix::try_from_csc_data(3, 4, col_offsets, row_indices, values)
|
||||||
|
/// .expect("CSC data must conform to format specifications");
|
||||||
|
/// assert_matrix_eq!(csc, dense);
|
||||||
|
///
|
||||||
|
/// // A third approach is to construct a CSC matrix from a pattern and values. Sometimes this is
|
||||||
|
/// // useful if the sparsity pattern is constructed separately from the values of the matrix.
|
||||||
|
/// let (pattern, values) = csc.into_pattern_and_values();
|
||||||
|
/// let csc = CscMatrix::try_from_pattern_and_values(pattern, values)
|
||||||
|
/// .expect("The pattern and values must be compatible");
|
||||||
|
///
|
||||||
|
/// // Once we have constructed our matrix, we can use it for arithmetic operations together with
|
||||||
|
/// // other CSC matrices and dense matrices/vectors.
|
||||||
|
/// let x = csc;
|
||||||
|
/// # #[allow(non_snake_case)]
|
||||||
|
/// let xTx = x.transpose() * &x;
|
||||||
|
/// let z = DMatrix::from_fn(4, 8, |i, j| (i as f64) * (j as f64));
|
||||||
|
/// let w = 3.0 * xTx * z;
|
||||||
|
///
|
||||||
|
/// // Although the sparsity pattern of a CSC matrix cannot be changed, its values can.
|
||||||
|
/// // Here are two different ways to scale all values by a constant:
|
||||||
|
/// let mut x = x;
|
||||||
|
/// x *= 5.0;
|
||||||
|
/// x.values_mut().iter_mut().for_each(|x_i| *x_i *= 5.0);
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
|
/// # Format
|
||||||
|
///
|
||||||
|
/// An `m x n` sparse matrix with `nnz` non-zeros in CSC format is represented by the
|
||||||
|
/// following three arrays:
|
||||||
|
///
|
||||||
|
/// - `col_offsets`, an array of integers with length `n + 1`.
|
||||||
|
/// - `row_indices`, an array of integers with length `nnz`.
|
||||||
|
/// - `values`, an array of values with length `nnz`.
|
||||||
|
///
|
||||||
|
/// The relationship between the arrays is described below.
|
||||||
|
///
|
||||||
|
/// - Each consecutive pair of entries `col_offsets[j] .. col_offsets[j + 1]` corresponds to an
|
||||||
|
/// offset range in `row_indices` that holds the row indices in column `j`.
|
||||||
|
/// - For an entry represented by the index `idx`, `row_indices[idx]` stores its column index and
|
||||||
|
/// `values[idx]` stores its value.
|
||||||
|
///
|
||||||
|
/// The following invariants must be upheld and are enforced by the data structure:
|
||||||
|
///
|
||||||
|
/// - `col_offsets[0] == 0`
|
||||||
|
/// - `col_offsets[m] == nnz`
|
||||||
|
/// - `col_offsets` is monotonically increasing.
|
||||||
|
/// - `0 <= row_indices[idx] < m` for all `idx < nnz`.
|
||||||
|
/// - The row indices associated with each column are monotonically increasing (see below).
|
||||||
|
///
|
||||||
|
/// The CSC format is a standard sparse matrix format (see [Wikipedia article]). The format
|
||||||
|
/// represents the matrix in a column-by-column fashion. The entries associated with column `j` are
|
||||||
|
/// determined as follows:
|
||||||
|
///
|
||||||
|
/// ```rust
|
||||||
|
/// # let col_offsets: Vec<usize> = vec![0, 0];
|
||||||
|
/// # let row_indices: Vec<usize> = vec![];
|
||||||
|
/// # let values: Vec<i32> = vec![];
|
||||||
|
/// # let j = 0;
|
||||||
|
/// let range = col_offsets[j] .. col_offsets[j + 1];
|
||||||
|
/// let col_j_rows = &row_indices[range.clone()];
|
||||||
|
/// let col_j_vals = &values[range];
|
||||||
|
///
|
||||||
|
/// // For each pair (i, v) in (col_j_rows, col_j_vals), we obtain a corresponding entry
|
||||||
|
/// // (i, j, v) in the matrix.
|
||||||
|
/// assert_eq!(col_j_rows.len(), col_j_vals.len());
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
|
/// In the above example, for each column `j`, the row indices `col_j_cols` must appear in
|
||||||
|
/// monotonically increasing order. In other words, they must be *sorted*. This criterion is not
|
||||||
|
/// standard among all sparse matrix libraries, but we enforce this property as it is a crucial
|
||||||
|
/// assumption for both correctness and performance for many algorithms.
|
||||||
|
///
|
||||||
|
/// Note that the CSR and CSC formats are essentially identical, except that CSC stores the matrix
|
||||||
|
/// column-by-column instead of row-by-row like CSR.
|
||||||
|
///
|
||||||
|
/// [Wikipedia article]: https://en.wikipedia.org/wiki/Sparse_matrix#Compressed_sparse_column_(CSC_or_CCS)
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub struct CscMatrix<T> {
|
||||||
|
// Cols are major, rows are minor in the sparsity pattern
|
||||||
|
pub(crate) cs: CsMatrix<T>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T> CscMatrix<T> {
|
||||||
|
/// Constructs a CSC representation of the (square) `n x n` identity matrix.
|
||||||
|
#[inline]
|
||||||
|
pub fn identity(n: usize) -> Self
|
||||||
|
where
|
||||||
|
T: Scalar + One,
|
||||||
|
{
|
||||||
|
Self {
|
||||||
|
cs: CsMatrix::identity(n),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a zero CSC matrix with no explicitly stored entries.
|
||||||
|
pub fn zeros(nrows: usize, ncols: usize) -> Self {
|
||||||
|
Self {
|
||||||
|
cs: CsMatrix::new(ncols, nrows),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Try to construct a CSC matrix from raw CSC data.
|
||||||
|
///
|
||||||
|
/// It is assumed that each column contains unique and sorted row indices that are in
|
||||||
|
/// bounds with respect to the number of rows in the matrix. If this is not the case,
|
||||||
|
/// an error is returned to indicate the failure.
|
||||||
|
///
|
||||||
|
/// An error is returned if the data given does not conform to the CSC storage format.
|
||||||
|
/// See the documentation for [CscMatrix](struct.CscMatrix.html) for more information.
|
||||||
|
pub fn try_from_csc_data(
|
||||||
|
num_rows: usize,
|
||||||
|
num_cols: usize,
|
||||||
|
col_offsets: Vec<usize>,
|
||||||
|
row_indices: Vec<usize>,
|
||||||
|
values: Vec<T>,
|
||||||
|
) -> Result<Self, SparseFormatError> {
|
||||||
|
let pattern = SparsityPattern::try_from_offsets_and_indices(
|
||||||
|
num_cols,
|
||||||
|
num_rows,
|
||||||
|
col_offsets,
|
||||||
|
row_indices,
|
||||||
|
)
|
||||||
|
.map_err(pattern_format_error_to_csc_error)?;
|
||||||
|
Self::try_from_pattern_and_values(pattern, values)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Try to construct a CSC matrix from a sparsity pattern and associated non-zero values.
|
||||||
|
///
|
||||||
|
/// Returns an error if the number of values does not match the number of minor indices
|
||||||
|
/// in the pattern.
|
||||||
|
pub fn try_from_pattern_and_values(
|
||||||
|
pattern: SparsityPattern,
|
||||||
|
values: Vec<T>,
|
||||||
|
) -> Result<Self, SparseFormatError> {
|
||||||
|
if pattern.nnz() == values.len() {
|
||||||
|
Ok(Self {
|
||||||
|
cs: CsMatrix::from_pattern_and_values(pattern, values),
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
Err(SparseFormatError::from_kind_and_msg(
|
||||||
|
SparseFormatErrorKind::InvalidStructure,
|
||||||
|
"Number of values and row indices must be the same",
|
||||||
|
))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The number of rows in the matrix.
|
||||||
|
#[inline]
|
||||||
|
pub fn nrows(&self) -> usize {
|
||||||
|
self.cs.pattern().minor_dim()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The number of columns in the matrix.
|
||||||
|
#[inline]
|
||||||
|
pub fn ncols(&self) -> usize {
|
||||||
|
self.cs.pattern().major_dim()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The number of non-zeros in the matrix.
|
||||||
|
///
|
||||||
|
/// Note that this corresponds to the number of explicitly stored entries, *not* the actual
|
||||||
|
/// number of algebraically zero entries in the matrix. Explicitly stored entries can still
|
||||||
|
/// be zero. Corresponds to the number of entries in the sparsity pattern.
|
||||||
|
#[inline]
|
||||||
|
pub fn nnz(&self) -> usize {
|
||||||
|
self.pattern().nnz()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The column offsets defining part of the CSC format.
|
||||||
|
#[inline]
|
||||||
|
pub fn col_offsets(&self) -> &[usize] {
|
||||||
|
self.pattern().major_offsets()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The row indices defining part of the CSC format.
|
||||||
|
#[inline]
|
||||||
|
pub fn row_indices(&self) -> &[usize] {
|
||||||
|
self.pattern().minor_indices()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The non-zero values defining part of the CSC format.
|
||||||
|
#[inline]
|
||||||
|
pub fn values(&self) -> &[T] {
|
||||||
|
self.cs.values()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Mutable access to the non-zero values.
|
||||||
|
#[inline]
|
||||||
|
pub fn values_mut(&mut self) -> &mut [T] {
|
||||||
|
self.cs.values_mut()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// An iterator over non-zero triplets (i, j, v).
|
||||||
|
///
|
||||||
|
/// The iteration happens in column-major fashion, meaning that j increases monotonically,
|
||||||
|
/// and i increases monotonically within each row.
|
||||||
|
///
|
||||||
|
/// Examples
|
||||||
|
/// --------
|
||||||
|
/// ```
|
||||||
|
/// # use nalgebra_sparse::csc::CscMatrix;
|
||||||
|
/// let col_offsets = vec![0, 2, 3, 4];
|
||||||
|
/// let row_indices = vec![0, 2, 1, 0];
|
||||||
|
/// let values = vec![1, 3, 2, 4];
|
||||||
|
/// let mut csc = CscMatrix::try_from_csc_data(4, 3, col_offsets, row_indices, values)
|
||||||
|
/// .unwrap();
|
||||||
|
///
|
||||||
|
/// let triplets: Vec<_> = csc.triplet_iter().map(|(i, j, v)| (i, j, *v)).collect();
|
||||||
|
/// assert_eq!(triplets, vec![(0, 0, 1), (2, 0, 3), (1, 1, 2), (0, 2, 4)]);
|
||||||
|
/// ```
|
||||||
|
pub fn triplet_iter(&self) -> CscTripletIter<T> {
|
||||||
|
CscTripletIter {
|
||||||
|
pattern_iter: self.pattern().entries(),
|
||||||
|
values_iter: self.values().iter(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A mutable iterator over non-zero triplets (i, j, v).
|
||||||
|
///
|
||||||
|
/// Iteration happens in the same order as for [triplet_iter](#method.triplet_iter).
|
||||||
|
///
|
||||||
|
/// Examples
|
||||||
|
/// --------
|
||||||
|
/// ```
|
||||||
|
/// # use nalgebra_sparse::csc::CscMatrix;
|
||||||
|
/// let col_offsets = vec![0, 2, 3, 4];
|
||||||
|
/// let row_indices = vec![0, 2, 1, 0];
|
||||||
|
/// let values = vec![1, 3, 2, 4];
|
||||||
|
/// // Using the same data as in the `triplet_iter` example
|
||||||
|
/// let mut csc = CscMatrix::try_from_csc_data(4, 3, col_offsets, row_indices, values)
|
||||||
|
/// .unwrap();
|
||||||
|
///
|
||||||
|
/// // Zero out lower-triangular terms
|
||||||
|
/// csc.triplet_iter_mut()
|
||||||
|
/// .filter(|(i, j, _)| j < i)
|
||||||
|
/// .for_each(|(_, _, v)| *v = 0);
|
||||||
|
///
|
||||||
|
/// let triplets: Vec<_> = csc.triplet_iter().map(|(i, j, v)| (i, j, *v)).collect();
|
||||||
|
/// assert_eq!(triplets, vec![(0, 0, 1), (2, 0, 0), (1, 1, 2), (0, 2, 4)]);
|
||||||
|
/// ```
|
||||||
|
pub fn triplet_iter_mut(&mut self) -> CscTripletIterMut<T> {
|
||||||
|
let (pattern, values) = self.cs.pattern_and_values_mut();
|
||||||
|
CscTripletIterMut {
|
||||||
|
pattern_iter: pattern.entries(),
|
||||||
|
values_mut_iter: values.iter_mut(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Return the column at the given column index.
|
||||||
|
///
|
||||||
|
/// Panics
|
||||||
|
/// ------
|
||||||
|
/// Panics if column index is out of bounds.
|
||||||
|
#[inline]
|
||||||
|
pub fn col(&self, index: usize) -> CscCol<T> {
|
||||||
|
self.get_col(index).expect("Row index must be in bounds")
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Mutable column access for the given column index.
|
||||||
|
///
|
||||||
|
/// Panics
|
||||||
|
/// ------
|
||||||
|
/// Panics if column index is out of bounds.
|
||||||
|
#[inline]
|
||||||
|
pub fn col_mut(&mut self, index: usize) -> CscColMut<T> {
|
||||||
|
self.get_col_mut(index)
|
||||||
|
.expect("Row index must be in bounds")
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Return the column at the given column index, or `None` if out of bounds.
|
||||||
|
#[inline]
|
||||||
|
pub fn get_col(&self, index: usize) -> Option<CscCol<T>> {
|
||||||
|
self.cs.get_lane(index).map(|lane| CscCol { lane })
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Mutable column access for the given column index, or `None` if out of bounds.
|
||||||
|
#[inline]
|
||||||
|
pub fn get_col_mut(&mut self, index: usize) -> Option<CscColMut<T>> {
|
||||||
|
self.cs.get_lane_mut(index).map(|lane| CscColMut { lane })
|
||||||
|
}
|
||||||
|
|
||||||
|
/// An iterator over columns in the matrix.
|
||||||
|
pub fn col_iter(&self) -> CscColIter<T> {
|
||||||
|
CscColIter {
|
||||||
|
lane_iter: CsLaneIter::new(self.pattern(), self.values()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A mutable iterator over columns in the matrix.
|
||||||
|
pub fn col_iter_mut(&mut self) -> CscColIterMut<T> {
|
||||||
|
let (pattern, values) = self.cs.pattern_and_values_mut();
|
||||||
|
CscColIterMut {
|
||||||
|
lane_iter: CsLaneIterMut::new(pattern, values),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Disassembles the CSC matrix into its underlying offset, index and value arrays.
|
||||||
|
///
|
||||||
|
/// If the matrix contains the sole reference to the sparsity pattern,
|
||||||
|
/// then the data is returned as-is. Otherwise, the sparsity pattern is cloned.
|
||||||
|
///
|
||||||
|
/// Examples
|
||||||
|
/// --------
|
||||||
|
///
|
||||||
|
/// ```
|
||||||
|
/// # use nalgebra_sparse::csc::CscMatrix;
|
||||||
|
/// let col_offsets = vec![0, 2, 3, 4];
|
||||||
|
/// let row_indices = vec![0, 2, 1, 0];
|
||||||
|
/// let values = vec![1, 3, 2, 4];
|
||||||
|
/// let mut csc = CscMatrix::try_from_csc_data(
|
||||||
|
/// 4,
|
||||||
|
/// 3,
|
||||||
|
/// col_offsets.clone(),
|
||||||
|
/// row_indices.clone(),
|
||||||
|
/// values.clone())
|
||||||
|
/// .unwrap();
|
||||||
|
/// let (col_offsets2, row_indices2, values2) = csc.disassemble();
|
||||||
|
/// assert_eq!(col_offsets2, col_offsets);
|
||||||
|
/// assert_eq!(row_indices2, row_indices);
|
||||||
|
/// assert_eq!(values2, values);
|
||||||
|
/// ```
|
||||||
|
pub fn disassemble(self) -> (Vec<usize>, Vec<usize>, Vec<T>) {
|
||||||
|
self.cs.disassemble()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the sparsity pattern and values associated with this matrix.
|
||||||
|
pub fn into_pattern_and_values(self) -> (SparsityPattern, Vec<T>) {
|
||||||
|
self.cs.into_pattern_and_values()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns a reference to the sparsity pattern and a mutable reference to the values.
|
||||||
|
#[inline]
|
||||||
|
pub fn pattern_and_values_mut(&mut self) -> (&SparsityPattern, &mut [T]) {
|
||||||
|
self.cs.pattern_and_values_mut()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns a reference to the underlying sparsity pattern.
|
||||||
|
pub fn pattern(&self) -> &SparsityPattern {
|
||||||
|
self.cs.pattern()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Reinterprets the CSC matrix as its transpose represented by a CSR matrix.
|
||||||
|
///
|
||||||
|
/// This operation does not touch the CSC data, and is effectively a no-op.
|
||||||
|
pub fn transpose_as_csr(self) -> CsrMatrix<T> {
|
||||||
|
let (pattern, values) = self.cs.take_pattern_and_values();
|
||||||
|
CsrMatrix::try_from_pattern_and_values(pattern, values).unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns an entry for the given row/col indices, or `None` if the indices are out of bounds.
|
||||||
|
///
|
||||||
|
/// Each call to this function incurs the cost of a binary search among the explicitly
|
||||||
|
/// stored row entries for the given column.
|
||||||
|
pub fn get_entry(&self, row_index: usize, col_index: usize) -> Option<SparseEntry<T>> {
|
||||||
|
self.cs.get_entry(col_index, row_index)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns a mutable entry for the given row/col indices, or `None` if the indices are out
|
||||||
|
/// of bounds.
|
||||||
|
///
|
||||||
|
/// Each call to this function incurs the cost of a binary search among the explicitly
|
||||||
|
/// stored row entries for the given column.
|
||||||
|
pub fn get_entry_mut(
|
||||||
|
&mut self,
|
||||||
|
row_index: usize,
|
||||||
|
col_index: usize,
|
||||||
|
) -> Option<SparseEntryMut<T>> {
|
||||||
|
self.cs.get_entry_mut(col_index, row_index)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns an entry for the given row/col indices.
|
||||||
|
///
|
||||||
|
/// Same as `get_entry`, except that it directly panics upon encountering row/col indices
|
||||||
|
/// out of bounds.
|
||||||
|
///
|
||||||
|
/// Panics
|
||||||
|
/// ------
|
||||||
|
/// Panics if `row_index` or `col_index` is out of bounds.
|
||||||
|
pub fn index_entry(&self, row_index: usize, col_index: usize) -> SparseEntry<T> {
|
||||||
|
self.get_entry(row_index, col_index)
|
||||||
|
.expect("Out of bounds matrix indices encountered")
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns a mutable entry for the given row/col indices.
|
||||||
|
///
|
||||||
|
/// Same as `get_entry_mut`, except that it directly panics upon encountering row/col indices
|
||||||
|
/// out of bounds.
|
||||||
|
///
|
||||||
|
/// Panics
|
||||||
|
/// ------
|
||||||
|
/// Panics if `row_index` or `col_index` is out of bounds.
|
||||||
|
pub fn index_entry_mut(&mut self, row_index: usize, col_index: usize) -> SparseEntryMut<T> {
|
||||||
|
self.get_entry_mut(row_index, col_index)
|
||||||
|
.expect("Out of bounds matrix indices encountered")
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns a triplet of slices `(row_offsets, col_indices, values)` that make up the CSC data.
|
||||||
|
pub fn csc_data(&self) -> (&[usize], &[usize], &[T]) {
|
||||||
|
self.cs.cs_data()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns a triplet of slices `(row_offsets, col_indices, values)` that make up the CSC data,
|
||||||
|
/// where the `values` array is mutable.
|
||||||
|
pub fn csc_data_mut(&mut self) -> (&[usize], &[usize], &mut [T]) {
|
||||||
|
self.cs.cs_data_mut()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Creates a sparse matrix that contains only the explicit entries decided by the
|
||||||
|
/// given predicate.
|
||||||
|
pub fn filter<P>(&self, predicate: P) -> Self
|
||||||
|
where
|
||||||
|
T: Clone,
|
||||||
|
P: Fn(usize, usize, &T) -> bool,
|
||||||
|
{
|
||||||
|
// Note: Predicate uses (row, col, value), so we have to switch around since
|
||||||
|
// cs uses (major, minor, value)
|
||||||
|
Self {
|
||||||
|
cs: self
|
||||||
|
.cs
|
||||||
|
.filter(|col_idx, row_idx, v| predicate(row_idx, col_idx, v)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns a new matrix representing the upper triangular part of this matrix.
|
||||||
|
///
|
||||||
|
/// The result includes the diagonal of the matrix.
|
||||||
|
pub fn upper_triangle(&self) -> Self
|
||||||
|
where
|
||||||
|
T: Clone,
|
||||||
|
{
|
||||||
|
self.filter(|i, j, _| i <= j)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns a new matrix representing the lower triangular part of this matrix.
|
||||||
|
///
|
||||||
|
/// The result includes the diagonal of the matrix.
|
||||||
|
pub fn lower_triangle(&self) -> Self
|
||||||
|
where
|
||||||
|
T: Clone,
|
||||||
|
{
|
||||||
|
self.filter(|i, j, _| i >= j)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the diagonal of the matrix as a sparse matrix.
|
||||||
|
pub fn diagonal_as_csc(&self) -> Self
|
||||||
|
where
|
||||||
|
T: Clone,
|
||||||
|
{
|
||||||
|
Self {
|
||||||
|
cs: self.cs.diagonal_as_matrix(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Compute the transpose of the matrix.
|
||||||
|
pub fn transpose(&self) -> CscMatrix<T>
|
||||||
|
where
|
||||||
|
T: Scalar,
|
||||||
|
{
|
||||||
|
CsrMatrix::from(self).transpose_as_csc()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Convert pattern format errors into more meaningful CSC-specific errors.
|
||||||
|
///
|
||||||
|
/// This ensures that the terminology is consistent: we are talking about rows and columns,
|
||||||
|
/// not lanes, major and minor dimensions.
|
||||||
|
fn pattern_format_error_to_csc_error(err: SparsityPatternFormatError) -> SparseFormatError {
|
||||||
|
use SparseFormatError as E;
|
||||||
|
use SparseFormatErrorKind as K;
|
||||||
|
use SparsityPatternFormatError::DuplicateEntry as PatternDuplicateEntry;
|
||||||
|
use SparsityPatternFormatError::*;
|
||||||
|
|
||||||
|
match err {
|
||||||
|
InvalidOffsetArrayLength => E::from_kind_and_msg(
|
||||||
|
K::InvalidStructure,
|
||||||
|
"Length of col offset array is not equal to ncols + 1.",
|
||||||
|
),
|
||||||
|
InvalidOffsetFirstLast => E::from_kind_and_msg(
|
||||||
|
K::InvalidStructure,
|
||||||
|
"First or last col offset is inconsistent with format specification.",
|
||||||
|
),
|
||||||
|
NonmonotonicOffsets => E::from_kind_and_msg(
|
||||||
|
K::InvalidStructure,
|
||||||
|
"Col offsets are not monotonically increasing.",
|
||||||
|
),
|
||||||
|
NonmonotonicMinorIndices => E::from_kind_and_msg(
|
||||||
|
K::InvalidStructure,
|
||||||
|
"Row indices are not monotonically increasing (sorted) within each column.",
|
||||||
|
),
|
||||||
|
MinorIndexOutOfBounds => {
|
||||||
|
E::from_kind_and_msg(K::IndexOutOfBounds, "Row indices are out of bounds.")
|
||||||
|
}
|
||||||
|
PatternDuplicateEntry => {
|
||||||
|
E::from_kind_and_msg(K::DuplicateEntry, "Matrix data contains duplicate entries.")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Iterator type for iterating over triplets in a CSC matrix.
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct CscTripletIter<'a, T> {
|
||||||
|
pattern_iter: SparsityPatternIter<'a>,
|
||||||
|
values_iter: Iter<'a, T>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, T: Clone> CscTripletIter<'a, T> {
|
||||||
|
/// Adapts the triplet iterator to return owned values.
|
||||||
|
///
|
||||||
|
/// The triplet iterator returns references to the values. This method adapts the iterator
|
||||||
|
/// so that the values are cloned.
|
||||||
|
#[inline]
|
||||||
|
pub fn cloned_values(self) -> impl 'a + Iterator<Item = (usize, usize, T)> {
|
||||||
|
self.map(|(i, j, v)| (i, j, v.clone()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, T> Iterator for CscTripletIter<'a, T> {
|
||||||
|
type Item = (usize, usize, &'a T);
|
||||||
|
|
||||||
|
fn next(&mut self) -> Option<Self::Item> {
|
||||||
|
let next_entry = self.pattern_iter.next();
|
||||||
|
let next_value = self.values_iter.next();
|
||||||
|
|
||||||
|
match (next_entry, next_value) {
|
||||||
|
(Some((i, j)), Some(v)) => Some((j, i, v)),
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Iterator type for mutably iterating over triplets in a CSC matrix.
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct CscTripletIterMut<'a, T> {
|
||||||
|
pattern_iter: SparsityPatternIter<'a>,
|
||||||
|
values_mut_iter: IterMut<'a, T>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, T> Iterator for CscTripletIterMut<'a, T> {
|
||||||
|
type Item = (usize, usize, &'a mut T);
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
fn next(&mut self) -> Option<Self::Item> {
|
||||||
|
let next_entry = self.pattern_iter.next();
|
||||||
|
let next_value = self.values_mut_iter.next();
|
||||||
|
|
||||||
|
match (next_entry, next_value) {
|
||||||
|
(Some((i, j)), Some(v)) => Some((j, i, v)),
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// An immutable representation of a column in a CSC matrix.
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub struct CscCol<'a, T> {
|
||||||
|
lane: CsLane<'a, T>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A mutable representation of a column in a CSC matrix.
|
||||||
|
///
|
||||||
|
/// Note that only explicitly stored entries can be mutated. The sparsity pattern belonging
|
||||||
|
/// to the column cannot be modified.
|
||||||
|
#[derive(Debug, PartialEq, Eq)]
|
||||||
|
pub struct CscColMut<'a, T> {
|
||||||
|
lane: CsLaneMut<'a, T>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Implement the methods common to both CscCol and CscColMut
|
||||||
|
macro_rules! impl_csc_col_common_methods {
|
||||||
|
($name:ty) => {
|
||||||
|
impl<'a, T> $name {
|
||||||
|
/// The number of global rows in the column.
|
||||||
|
#[inline]
|
||||||
|
pub fn nrows(&self) -> usize {
|
||||||
|
self.lane.minor_dim()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The number of non-zeros in this column.
|
||||||
|
#[inline]
|
||||||
|
pub fn nnz(&self) -> usize {
|
||||||
|
self.lane.nnz()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The row indices corresponding to explicitly stored entries in this column.
|
||||||
|
#[inline]
|
||||||
|
pub fn row_indices(&self) -> &[usize] {
|
||||||
|
self.lane.minor_indices()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The values corresponding to explicitly stored entries in this column.
|
||||||
|
#[inline]
|
||||||
|
pub fn values(&self) -> &[T] {
|
||||||
|
self.lane.values()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns an entry for the given global row index.
|
||||||
|
///
|
||||||
|
/// Each call to this function incurs the cost of a binary search among the explicitly
|
||||||
|
/// stored row entries.
|
||||||
|
pub fn get_entry(&self, global_row_index: usize) -> Option<SparseEntry<T>> {
|
||||||
|
self.lane.get_entry(global_row_index)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
impl_csc_col_common_methods!(CscCol<'a, T>);
|
||||||
|
impl_csc_col_common_methods!(CscColMut<'a, T>);
|
||||||
|
|
||||||
|
impl<'a, T> CscColMut<'a, T> {
|
||||||
|
/// Mutable access to the values corresponding to explicitly stored entries in this column.
|
||||||
|
pub fn values_mut(&mut self) -> &mut [T] {
|
||||||
|
self.lane.values_mut()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Provides simultaneous access to row indices and mutable values corresponding to the
|
||||||
|
/// explicitly stored entries in this column.
|
||||||
|
///
|
||||||
|
/// This method primarily facilitates low-level access for methods that process data stored
|
||||||
|
/// in CSC format directly.
|
||||||
|
pub fn rows_and_values_mut(&mut self) -> (&[usize], &mut [T]) {
|
||||||
|
self.lane.indices_and_values_mut()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns a mutable entry for the given global row index.
|
||||||
|
pub fn get_entry_mut(&mut self, global_row_index: usize) -> Option<SparseEntryMut<T>> {
|
||||||
|
self.lane.get_entry_mut(global_row_index)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Column iterator for [CscMatrix](struct.CscMatrix.html).
|
||||||
|
pub struct CscColIter<'a, T> {
|
||||||
|
lane_iter: CsLaneIter<'a, T>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, T> Iterator for CscColIter<'a, T> {
|
||||||
|
type Item = CscCol<'a, T>;
|
||||||
|
|
||||||
|
fn next(&mut self) -> Option<Self::Item> {
|
||||||
|
self.lane_iter.next().map(|lane| CscCol { lane })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Mutable column iterator for [CscMatrix](struct.CscMatrix.html).
|
||||||
|
pub struct CscColIterMut<'a, T> {
|
||||||
|
lane_iter: CsLaneIterMut<'a, T>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, T> Iterator for CscColIterMut<'a, T>
|
||||||
|
where
|
||||||
|
T: 'a,
|
||||||
|
{
|
||||||
|
type Item = CscColMut<'a, T>;
|
||||||
|
|
||||||
|
fn next(&mut self) -> Option<Self::Item> {
|
||||||
|
self.lane_iter.next().map(|lane| CscColMut { lane })
|
||||||
|
}
|
||||||
|
}
|
708
nalgebra-sparse/src/csr.rs
Normal file
708
nalgebra-sparse/src/csr.rs
Normal file
@ -0,0 +1,708 @@
|
|||||||
|
//! An implementation of the CSR sparse matrix format.
|
||||||
|
//!
|
||||||
|
//! This is the module-level documentation. See [`CsrMatrix`] for the main documentation of the
|
||||||
|
//! CSC implementation.
|
||||||
|
use crate::cs::{CsLane, CsLaneIter, CsLaneIterMut, CsLaneMut, CsMatrix};
|
||||||
|
use crate::csc::CscMatrix;
|
||||||
|
use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter};
|
||||||
|
use crate::{SparseEntry, SparseEntryMut, SparseFormatError, SparseFormatErrorKind};
|
||||||
|
|
||||||
|
use nalgebra::Scalar;
|
||||||
|
use num_traits::One;
|
||||||
|
|
||||||
|
use std::slice::{Iter, IterMut};
|
||||||
|
|
||||||
|
/// A CSR representation of a sparse matrix.
|
||||||
|
///
|
||||||
|
/// The Compressed Sparse Row (CSR) format is well-suited as a general-purpose storage format
|
||||||
|
/// for many sparse matrix applications.
|
||||||
|
///
|
||||||
|
/// # Usage
|
||||||
|
///
|
||||||
|
/// ```rust
|
||||||
|
/// use nalgebra_sparse::csr::CsrMatrix;
|
||||||
|
/// use nalgebra::{DMatrix, Matrix3x4};
|
||||||
|
/// use matrixcompare::assert_matrix_eq;
|
||||||
|
///
|
||||||
|
/// // The sparsity patterns of CSR matrices are immutable. This means that you cannot dynamically
|
||||||
|
/// // change the sparsity pattern of the matrix after it has been constructed. The easiest
|
||||||
|
/// // way to construct a CSR matrix is to first incrementally construct a COO matrix,
|
||||||
|
/// // and then convert it to CSR.
|
||||||
|
/// # use nalgebra_sparse::coo::CooMatrix;
|
||||||
|
/// # let coo = CooMatrix::<f64>::new(3, 3);
|
||||||
|
/// let csr = CsrMatrix::from(&coo);
|
||||||
|
///
|
||||||
|
/// // Alternatively, a CSR matrix can be constructed directly from raw CSR data.
|
||||||
|
/// // Here, we construct a 3x4 matrix
|
||||||
|
/// let row_offsets = vec![0, 3, 3, 5];
|
||||||
|
/// let col_indices = vec![0, 1, 3, 1, 2];
|
||||||
|
/// let values = vec![1.0, 2.0, 3.0, 4.0, 5.0];
|
||||||
|
///
|
||||||
|
/// // The dense representation of the CSR data, for comparison
|
||||||
|
/// let dense = Matrix3x4::new(1.0, 2.0, 0.0, 3.0,
|
||||||
|
/// 0.0, 0.0, 0.0, 0.0,
|
||||||
|
/// 0.0, 4.0, 5.0, 0.0);
|
||||||
|
///
|
||||||
|
/// // The constructor validates the raw CSR data and returns an error if it is invalid.
|
||||||
|
/// let csr = CsrMatrix::try_from_csr_data(3, 4, row_offsets, col_indices, values)
|
||||||
|
/// .expect("CSR data must conform to format specifications");
|
||||||
|
/// assert_matrix_eq!(csr, dense);
|
||||||
|
///
|
||||||
|
/// // A third approach is to construct a CSR matrix from a pattern and values. Sometimes this is
|
||||||
|
/// // useful if the sparsity pattern is constructed separately from the values of the matrix.
|
||||||
|
/// let (pattern, values) = csr.into_pattern_and_values();
|
||||||
|
/// let csr = CsrMatrix::try_from_pattern_and_values(pattern, values)
|
||||||
|
/// .expect("The pattern and values must be compatible");
|
||||||
|
///
|
||||||
|
/// // Once we have constructed our matrix, we can use it for arithmetic operations together with
|
||||||
|
/// // other CSR matrices and dense matrices/vectors.
|
||||||
|
/// let x = csr;
|
||||||
|
/// # #[allow(non_snake_case)]
|
||||||
|
/// let xTx = x.transpose() * &x;
|
||||||
|
/// let z = DMatrix::from_fn(4, 8, |i, j| (i as f64) * (j as f64));
|
||||||
|
/// let w = 3.0 * xTx * z;
|
||||||
|
///
|
||||||
|
/// // Although the sparsity pattern of a CSR matrix cannot be changed, its values can.
|
||||||
|
/// // Here are two different ways to scale all values by a constant:
|
||||||
|
/// let mut x = x;
|
||||||
|
/// x *= 5.0;
|
||||||
|
/// x.values_mut().iter_mut().for_each(|x_i| *x_i *= 5.0);
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
|
/// # Format
|
||||||
|
///
|
||||||
|
/// An `m x n` sparse matrix with `nnz` non-zeros in CSR format is represented by the
|
||||||
|
/// following three arrays:
|
||||||
|
///
|
||||||
|
/// - `row_offsets`, an array of integers with length `m + 1`.
|
||||||
|
/// - `col_indices`, an array of integers with length `nnz`.
|
||||||
|
/// - `values`, an array of values with length `nnz`.
|
||||||
|
///
|
||||||
|
/// The relationship between the arrays is described below.
|
||||||
|
///
|
||||||
|
/// - Each consecutive pair of entries `row_offsets[i] .. row_offsets[i + 1]` corresponds to an
|
||||||
|
/// offset range in `col_indices` that holds the column indices in row `i`.
|
||||||
|
/// - For an entry represented by the index `idx`, `col_indices[idx]` stores its column index and
|
||||||
|
/// `values[idx]` stores its value.
|
||||||
|
///
|
||||||
|
/// The following invariants must be upheld and are enforced by the data structure:
|
||||||
|
///
|
||||||
|
/// - `row_offsets[0] == 0`
|
||||||
|
/// - `row_offsets[m] == nnz`
|
||||||
|
/// - `row_offsets` is monotonically increasing.
|
||||||
|
/// - `0 <= col_indices[idx] < n` for all `idx < nnz`.
|
||||||
|
/// - The column indices associated with each row are monotonically increasing (see below).
|
||||||
|
///
|
||||||
|
/// The CSR format is a standard sparse matrix format (see [Wikipedia article]). The format
|
||||||
|
/// represents the matrix in a row-by-row fashion. The entries associated with row `i` are
|
||||||
|
/// determined as follows:
|
||||||
|
///
|
||||||
|
/// ```rust
|
||||||
|
/// # let row_offsets: Vec<usize> = vec![0, 0];
|
||||||
|
/// # let col_indices: Vec<usize> = vec![];
|
||||||
|
/// # let values: Vec<i32> = vec![];
|
||||||
|
/// # let i = 0;
|
||||||
|
/// let range = row_offsets[i] .. row_offsets[i + 1];
|
||||||
|
/// let row_i_cols = &col_indices[range.clone()];
|
||||||
|
/// let row_i_vals = &values[range];
|
||||||
|
///
|
||||||
|
/// // For each pair (j, v) in (row_i_cols, row_i_vals), we obtain a corresponding entry
|
||||||
|
/// // (i, j, v) in the matrix.
|
||||||
|
/// assert_eq!(row_i_cols.len(), row_i_vals.len());
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
|
/// In the above example, for each row `i`, the column indices `row_i_cols` must appear in
|
||||||
|
/// monotonically increasing order. In other words, they must be *sorted*. This criterion is not
|
||||||
|
/// standard among all sparse matrix libraries, but we enforce this property as it is a crucial
|
||||||
|
/// assumption for both correctness and performance for many algorithms.
|
||||||
|
///
|
||||||
|
/// Note that the CSR and CSC formats are essentially identical, except that CSC stores the matrix
|
||||||
|
/// column-by-column instead of row-by-row like CSR.
|
||||||
|
///
|
||||||
|
/// [Wikipedia article]: https://en.wikipedia.org/wiki/Sparse_matrix#Compressed_sparse_row_(CSR,_CRS_or_Yale_format)
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub struct CsrMatrix<T> {
|
||||||
|
// Rows are major, cols are minor in the sparsity pattern
|
||||||
|
pub(crate) cs: CsMatrix<T>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T> CsrMatrix<T> {
|
||||||
|
/// Constructs a CSR representation of the (square) `n x n` identity matrix.
|
||||||
|
#[inline]
|
||||||
|
pub fn identity(n: usize) -> Self
|
||||||
|
where
|
||||||
|
T: Scalar + One,
|
||||||
|
{
|
||||||
|
Self {
|
||||||
|
cs: CsMatrix::identity(n),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a zero CSR matrix with no explicitly stored entries.
|
||||||
|
pub fn zeros(nrows: usize, ncols: usize) -> Self {
|
||||||
|
Self {
|
||||||
|
cs: CsMatrix::new(nrows, ncols),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Try to construct a CSR matrix from raw CSR data.
|
||||||
|
///
|
||||||
|
/// It is assumed that each row contains unique and sorted column indices that are in
|
||||||
|
/// bounds with respect to the number of columns in the matrix. If this is not the case,
|
||||||
|
/// an error is returned to indicate the failure.
|
||||||
|
///
|
||||||
|
/// An error is returned if the data given does not conform to the CSR storage format.
|
||||||
|
/// See the documentation for [CsrMatrix](struct.CsrMatrix.html) for more information.
|
||||||
|
pub fn try_from_csr_data(
|
||||||
|
num_rows: usize,
|
||||||
|
num_cols: usize,
|
||||||
|
row_offsets: Vec<usize>,
|
||||||
|
col_indices: Vec<usize>,
|
||||||
|
values: Vec<T>,
|
||||||
|
) -> Result<Self, SparseFormatError> {
|
||||||
|
let pattern = SparsityPattern::try_from_offsets_and_indices(
|
||||||
|
num_rows,
|
||||||
|
num_cols,
|
||||||
|
row_offsets,
|
||||||
|
col_indices,
|
||||||
|
)
|
||||||
|
.map_err(pattern_format_error_to_csr_error)?;
|
||||||
|
Self::try_from_pattern_and_values(pattern, values)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Try to construct a CSR matrix from a sparsity pattern and associated non-zero values.
|
||||||
|
///
|
||||||
|
/// Returns an error if the number of values does not match the number of minor indices
|
||||||
|
/// in the pattern.
|
||||||
|
pub fn try_from_pattern_and_values(
|
||||||
|
pattern: SparsityPattern,
|
||||||
|
values: Vec<T>,
|
||||||
|
) -> Result<Self, SparseFormatError> {
|
||||||
|
if pattern.nnz() == values.len() {
|
||||||
|
Ok(Self {
|
||||||
|
cs: CsMatrix::from_pattern_and_values(pattern, values),
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
Err(SparseFormatError::from_kind_and_msg(
|
||||||
|
SparseFormatErrorKind::InvalidStructure,
|
||||||
|
"Number of values and column indices must be the same",
|
||||||
|
))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The number of rows in the matrix.
|
||||||
|
#[inline]
|
||||||
|
pub fn nrows(&self) -> usize {
|
||||||
|
self.cs.pattern().major_dim()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The number of columns in the matrix.
|
||||||
|
#[inline]
|
||||||
|
pub fn ncols(&self) -> usize {
|
||||||
|
self.cs.pattern().minor_dim()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The number of non-zeros in the matrix.
|
||||||
|
///
|
||||||
|
/// Note that this corresponds to the number of explicitly stored entries, *not* the actual
|
||||||
|
/// number of algebraically zero entries in the matrix. Explicitly stored entries can still
|
||||||
|
/// be zero. Corresponds to the number of entries in the sparsity pattern.
|
||||||
|
#[inline]
|
||||||
|
pub fn nnz(&self) -> usize {
|
||||||
|
self.cs.pattern().nnz()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The row offsets defining part of the CSR format.
|
||||||
|
#[inline]
|
||||||
|
pub fn row_offsets(&self) -> &[usize] {
|
||||||
|
let (offsets, _, _) = self.cs.cs_data();
|
||||||
|
offsets
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The column indices defining part of the CSR format.
|
||||||
|
#[inline]
|
||||||
|
pub fn col_indices(&self) -> &[usize] {
|
||||||
|
let (_, indices, _) = self.cs.cs_data();
|
||||||
|
indices
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The non-zero values defining part of the CSR format.
|
||||||
|
#[inline]
|
||||||
|
pub fn values(&self) -> &[T] {
|
||||||
|
self.cs.values()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Mutable access to the non-zero values.
|
||||||
|
#[inline]
|
||||||
|
pub fn values_mut(&mut self) -> &mut [T] {
|
||||||
|
self.cs.values_mut()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// An iterator over non-zero triplets (i, j, v).
|
||||||
|
///
|
||||||
|
/// The iteration happens in row-major fashion, meaning that i increases monotonically,
|
||||||
|
/// and j increases monotonically within each row.
|
||||||
|
///
|
||||||
|
/// Examples
|
||||||
|
/// --------
|
||||||
|
/// ```
|
||||||
|
/// # use nalgebra_sparse::csr::CsrMatrix;
|
||||||
|
/// let row_offsets = vec![0, 2, 3, 4];
|
||||||
|
/// let col_indices = vec![0, 2, 1, 0];
|
||||||
|
/// let values = vec![1, 2, 3, 4];
|
||||||
|
/// let mut csr = CsrMatrix::try_from_csr_data(3, 4, row_offsets, col_indices, values)
|
||||||
|
/// .unwrap();
|
||||||
|
///
|
||||||
|
/// let triplets: Vec<_> = csr.triplet_iter().map(|(i, j, v)| (i, j, *v)).collect();
|
||||||
|
/// assert_eq!(triplets, vec![(0, 0, 1), (0, 2, 2), (1, 1, 3), (2, 0, 4)]);
|
||||||
|
/// ```
|
||||||
|
pub fn triplet_iter(&self) -> CsrTripletIter<T> {
|
||||||
|
CsrTripletIter {
|
||||||
|
pattern_iter: self.pattern().entries(),
|
||||||
|
values_iter: self.values().iter(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A mutable iterator over non-zero triplets (i, j, v).
|
||||||
|
///
|
||||||
|
/// Iteration happens in the same order as for [triplet_iter](#method.triplet_iter).
|
||||||
|
///
|
||||||
|
/// Examples
|
||||||
|
/// --------
|
||||||
|
/// ```
|
||||||
|
/// # use nalgebra_sparse::csr::CsrMatrix;
|
||||||
|
/// # let row_offsets = vec![0, 2, 3, 4];
|
||||||
|
/// # let col_indices = vec![0, 2, 1, 0];
|
||||||
|
/// # let values = vec![1, 2, 3, 4];
|
||||||
|
/// // Using the same data as in the `triplet_iter` example
|
||||||
|
/// let mut csr = CsrMatrix::try_from_csr_data(3, 4, row_offsets, col_indices, values)
|
||||||
|
/// .unwrap();
|
||||||
|
///
|
||||||
|
/// // Zero out lower-triangular terms
|
||||||
|
/// csr.triplet_iter_mut()
|
||||||
|
/// .filter(|(i, j, _)| j < i)
|
||||||
|
/// .for_each(|(_, _, v)| *v = 0);
|
||||||
|
///
|
||||||
|
/// let triplets: Vec<_> = csr.triplet_iter().map(|(i, j, v)| (i, j, *v)).collect();
|
||||||
|
/// assert_eq!(triplets, vec![(0, 0, 1), (0, 2, 2), (1, 1, 3), (2, 0, 0)]);
|
||||||
|
/// ```
|
||||||
|
pub fn triplet_iter_mut(&mut self) -> CsrTripletIterMut<T> {
|
||||||
|
let (pattern, values) = self.cs.pattern_and_values_mut();
|
||||||
|
CsrTripletIterMut {
|
||||||
|
pattern_iter: pattern.entries(),
|
||||||
|
values_mut_iter: values.iter_mut(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Return the row at the given row index.
|
||||||
|
///
|
||||||
|
/// Panics
|
||||||
|
/// ------
|
||||||
|
/// Panics if row index is out of bounds.
|
||||||
|
#[inline]
|
||||||
|
pub fn row(&self, index: usize) -> CsrRow<T> {
|
||||||
|
self.get_row(index).expect("Row index must be in bounds")
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Mutable row access for the given row index.
|
||||||
|
///
|
||||||
|
/// Panics
|
||||||
|
/// ------
|
||||||
|
/// Panics if row index is out of bounds.
|
||||||
|
#[inline]
|
||||||
|
pub fn row_mut(&mut self, index: usize) -> CsrRowMut<T> {
|
||||||
|
self.get_row_mut(index)
|
||||||
|
.expect("Row index must be in bounds")
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Return the row at the given row index, or `None` if out of bounds.
|
||||||
|
#[inline]
|
||||||
|
pub fn get_row(&self, index: usize) -> Option<CsrRow<T>> {
|
||||||
|
self.cs.get_lane(index).map(|lane| CsrRow { lane })
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Mutable row access for the given row index, or `None` if out of bounds.
|
||||||
|
#[inline]
|
||||||
|
pub fn get_row_mut(&mut self, index: usize) -> Option<CsrRowMut<T>> {
|
||||||
|
self.cs.get_lane_mut(index).map(|lane| CsrRowMut { lane })
|
||||||
|
}
|
||||||
|
|
||||||
|
/// An iterator over rows in the matrix.
|
||||||
|
pub fn row_iter(&self) -> CsrRowIter<T> {
|
||||||
|
CsrRowIter {
|
||||||
|
lane_iter: CsLaneIter::new(self.pattern(), self.values()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A mutable iterator over rows in the matrix.
|
||||||
|
pub fn row_iter_mut(&mut self) -> CsrRowIterMut<T> {
|
||||||
|
let (pattern, values) = self.cs.pattern_and_values_mut();
|
||||||
|
CsrRowIterMut {
|
||||||
|
lane_iter: CsLaneIterMut::new(pattern, values),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Disassembles the CSR matrix into its underlying offset, index and value arrays.
|
||||||
|
///
|
||||||
|
/// If the matrix contains the sole reference to the sparsity pattern,
|
||||||
|
/// then the data is returned as-is. Otherwise, the sparsity pattern is cloned.
|
||||||
|
///
|
||||||
|
/// Examples
|
||||||
|
/// --------
|
||||||
|
///
|
||||||
|
/// ```
|
||||||
|
/// # use nalgebra_sparse::csr::CsrMatrix;
|
||||||
|
/// let row_offsets = vec![0, 2, 3, 4];
|
||||||
|
/// let col_indices = vec![0, 2, 1, 0];
|
||||||
|
/// let values = vec![1, 2, 3, 4];
|
||||||
|
/// let mut csr = CsrMatrix::try_from_csr_data(
|
||||||
|
/// 3,
|
||||||
|
/// 4,
|
||||||
|
/// row_offsets.clone(),
|
||||||
|
/// col_indices.clone(),
|
||||||
|
/// values.clone())
|
||||||
|
/// .unwrap();
|
||||||
|
/// let (row_offsets2, col_indices2, values2) = csr.disassemble();
|
||||||
|
/// assert_eq!(row_offsets2, row_offsets);
|
||||||
|
/// assert_eq!(col_indices2, col_indices);
|
||||||
|
/// assert_eq!(values2, values);
|
||||||
|
/// ```
|
||||||
|
pub fn disassemble(self) -> (Vec<usize>, Vec<usize>, Vec<T>) {
|
||||||
|
self.cs.disassemble()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the sparsity pattern and values associated with this matrix.
|
||||||
|
pub fn into_pattern_and_values(self) -> (SparsityPattern, Vec<T>) {
|
||||||
|
self.cs.into_pattern_and_values()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns a reference to the sparsity pattern and a mutable reference to the values.
|
||||||
|
#[inline]
|
||||||
|
pub fn pattern_and_values_mut(&mut self) -> (&SparsityPattern, &mut [T]) {
|
||||||
|
self.cs.pattern_and_values_mut()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns a reference to the underlying sparsity pattern.
|
||||||
|
pub fn pattern(&self) -> &SparsityPattern {
|
||||||
|
self.cs.pattern()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Reinterprets the CSR matrix as its transpose represented by a CSC matrix.
|
||||||
|
///
|
||||||
|
/// This operation does not touch the CSR data, and is effectively a no-op.
|
||||||
|
pub fn transpose_as_csc(self) -> CscMatrix<T> {
|
||||||
|
let (pattern, values) = self.cs.take_pattern_and_values();
|
||||||
|
CscMatrix::try_from_pattern_and_values(pattern, values).unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns an entry for the given row/col indices, or `None` if the indices are out of bounds.
|
||||||
|
///
|
||||||
|
/// Each call to this function incurs the cost of a binary search among the explicitly
|
||||||
|
/// stored column entries for the given row.
|
||||||
|
pub fn get_entry(&self, row_index: usize, col_index: usize) -> Option<SparseEntry<T>> {
|
||||||
|
self.cs.get_entry(row_index, col_index)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns a mutable entry for the given row/col indices, or `None` if the indices are out
|
||||||
|
/// of bounds.
|
||||||
|
///
|
||||||
|
/// Each call to this function incurs the cost of a binary search among the explicitly
|
||||||
|
/// stored column entries for the given row.
|
||||||
|
pub fn get_entry_mut(
|
||||||
|
&mut self,
|
||||||
|
row_index: usize,
|
||||||
|
col_index: usize,
|
||||||
|
) -> Option<SparseEntryMut<T>> {
|
||||||
|
self.cs.get_entry_mut(row_index, col_index)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns an entry for the given row/col indices.
|
||||||
|
///
|
||||||
|
/// Same as `get_entry`, except that it directly panics upon encountering row/col indices
|
||||||
|
/// out of bounds.
|
||||||
|
///
|
||||||
|
/// Panics
|
||||||
|
/// ------
|
||||||
|
/// Panics if `row_index` or `col_index` is out of bounds.
|
||||||
|
pub fn index_entry(&self, row_index: usize, col_index: usize) -> SparseEntry<T> {
|
||||||
|
self.get_entry(row_index, col_index)
|
||||||
|
.expect("Out of bounds matrix indices encountered")
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns a mutable entry for the given row/col indices.
|
||||||
|
///
|
||||||
|
/// Same as `get_entry_mut`, except that it directly panics upon encountering row/col indices
|
||||||
|
/// out of bounds.
|
||||||
|
///
|
||||||
|
/// Panics
|
||||||
|
/// ------
|
||||||
|
/// Panics if `row_index` or `col_index` is out of bounds.
|
||||||
|
pub fn index_entry_mut(&mut self, row_index: usize, col_index: usize) -> SparseEntryMut<T> {
|
||||||
|
self.get_entry_mut(row_index, col_index)
|
||||||
|
.expect("Out of bounds matrix indices encountered")
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns a triplet of slices `(row_offsets, col_indices, values)` that make up the CSR data.
|
||||||
|
pub fn csr_data(&self) -> (&[usize], &[usize], &[T]) {
|
||||||
|
self.cs.cs_data()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns a triplet of slices `(row_offsets, col_indices, values)` that make up the CSR data,
|
||||||
|
/// where the `values` array is mutable.
|
||||||
|
pub fn csr_data_mut(&mut self) -> (&[usize], &[usize], &mut [T]) {
|
||||||
|
self.cs.cs_data_mut()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Creates a sparse matrix that contains only the explicit entries decided by the
|
||||||
|
/// given predicate.
|
||||||
|
pub fn filter<P>(&self, predicate: P) -> Self
|
||||||
|
where
|
||||||
|
T: Clone,
|
||||||
|
P: Fn(usize, usize, &T) -> bool,
|
||||||
|
{
|
||||||
|
Self {
|
||||||
|
cs: self
|
||||||
|
.cs
|
||||||
|
.filter(|row_idx, col_idx, v| predicate(row_idx, col_idx, v)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns a new matrix representing the upper triangular part of this matrix.
|
||||||
|
///
|
||||||
|
/// The result includes the diagonal of the matrix.
|
||||||
|
pub fn upper_triangle(&self) -> Self
|
||||||
|
where
|
||||||
|
T: Clone,
|
||||||
|
{
|
||||||
|
self.filter(|i, j, _| i <= j)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns a new matrix representing the lower triangular part of this matrix.
|
||||||
|
///
|
||||||
|
/// The result includes the diagonal of the matrix.
|
||||||
|
pub fn lower_triangle(&self) -> Self
|
||||||
|
where
|
||||||
|
T: Clone,
|
||||||
|
{
|
||||||
|
self.filter(|i, j, _| i >= j)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the diagonal of the matrix as a sparse matrix.
|
||||||
|
pub fn diagonal_as_csr(&self) -> Self
|
||||||
|
where
|
||||||
|
T: Clone,
|
||||||
|
{
|
||||||
|
Self {
|
||||||
|
cs: self.cs.diagonal_as_matrix(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Compute the transpose of the matrix.
|
||||||
|
pub fn transpose(&self) -> CsrMatrix<T>
|
||||||
|
where
|
||||||
|
T: Scalar,
|
||||||
|
{
|
||||||
|
CscMatrix::from(self).transpose_as_csr()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Convert pattern format errors into more meaningful CSR-specific errors.
|
||||||
|
///
|
||||||
|
/// This ensures that the terminology is consistent: we are talking about rows and columns,
|
||||||
|
/// not lanes, major and minor dimensions.
|
||||||
|
fn pattern_format_error_to_csr_error(err: SparsityPatternFormatError) -> SparseFormatError {
|
||||||
|
use SparseFormatError as E;
|
||||||
|
use SparseFormatErrorKind as K;
|
||||||
|
use SparsityPatternFormatError::DuplicateEntry as PatternDuplicateEntry;
|
||||||
|
use SparsityPatternFormatError::*;
|
||||||
|
|
||||||
|
match err {
|
||||||
|
InvalidOffsetArrayLength => E::from_kind_and_msg(
|
||||||
|
K::InvalidStructure,
|
||||||
|
"Length of row offset array is not equal to nrows + 1.",
|
||||||
|
),
|
||||||
|
InvalidOffsetFirstLast => E::from_kind_and_msg(
|
||||||
|
K::InvalidStructure,
|
||||||
|
"First or last row offset is inconsistent with format specification.",
|
||||||
|
),
|
||||||
|
NonmonotonicOffsets => E::from_kind_and_msg(
|
||||||
|
K::InvalidStructure,
|
||||||
|
"Row offsets are not monotonically increasing.",
|
||||||
|
),
|
||||||
|
NonmonotonicMinorIndices => E::from_kind_and_msg(
|
||||||
|
K::InvalidStructure,
|
||||||
|
"Column indices are not monotonically increasing (sorted) within each row.",
|
||||||
|
),
|
||||||
|
MinorIndexOutOfBounds => {
|
||||||
|
E::from_kind_and_msg(K::IndexOutOfBounds, "Column indices are out of bounds.")
|
||||||
|
}
|
||||||
|
PatternDuplicateEntry => {
|
||||||
|
E::from_kind_and_msg(K::DuplicateEntry, "Matrix data contains duplicate entries.")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Iterator type for iterating over triplets in a CSR matrix.
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct CsrTripletIter<'a, T> {
|
||||||
|
pattern_iter: SparsityPatternIter<'a>,
|
||||||
|
values_iter: Iter<'a, T>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, T: Clone> CsrTripletIter<'a, T> {
|
||||||
|
/// Adapts the triplet iterator to return owned values.
|
||||||
|
///
|
||||||
|
/// The triplet iterator returns references to the values. This method adapts the iterator
|
||||||
|
/// so that the values are cloned.
|
||||||
|
#[inline]
|
||||||
|
pub fn cloned_values(self) -> impl 'a + Iterator<Item = (usize, usize, T)> {
|
||||||
|
self.map(|(i, j, v)| (i, j, v.clone()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, T> Iterator for CsrTripletIter<'a, T> {
|
||||||
|
type Item = (usize, usize, &'a T);
|
||||||
|
|
||||||
|
fn next(&mut self) -> Option<Self::Item> {
|
||||||
|
let next_entry = self.pattern_iter.next();
|
||||||
|
let next_value = self.values_iter.next();
|
||||||
|
|
||||||
|
match (next_entry, next_value) {
|
||||||
|
(Some((i, j)), Some(v)) => Some((i, j, v)),
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Iterator type for mutably iterating over triplets in a CSR matrix.
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct CsrTripletIterMut<'a, T> {
|
||||||
|
pattern_iter: SparsityPatternIter<'a>,
|
||||||
|
values_mut_iter: IterMut<'a, T>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, T> Iterator for CsrTripletIterMut<'a, T> {
|
||||||
|
type Item = (usize, usize, &'a mut T);
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
fn next(&mut self) -> Option<Self::Item> {
|
||||||
|
let next_entry = self.pattern_iter.next();
|
||||||
|
let next_value = self.values_mut_iter.next();
|
||||||
|
|
||||||
|
match (next_entry, next_value) {
|
||||||
|
(Some((i, j)), Some(v)) => Some((i, j, v)),
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// An immutable representation of a row in a CSR matrix.
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub struct CsrRow<'a, T> {
|
||||||
|
lane: CsLane<'a, T>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A mutable representation of a row in a CSR matrix.
|
||||||
|
///
|
||||||
|
/// Note that only explicitly stored entries can be mutated. The sparsity pattern belonging
|
||||||
|
/// to the row cannot be modified.
|
||||||
|
#[derive(Debug, PartialEq, Eq)]
|
||||||
|
pub struct CsrRowMut<'a, T> {
|
||||||
|
lane: CsLaneMut<'a, T>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Implement the methods common to both CsrRow and CsrRowMut
|
||||||
|
macro_rules! impl_csr_row_common_methods {
|
||||||
|
($name:ty) => {
|
||||||
|
impl<'a, T> $name {
|
||||||
|
/// The number of global columns in the row.
|
||||||
|
#[inline]
|
||||||
|
pub fn ncols(&self) -> usize {
|
||||||
|
self.lane.minor_dim()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The number of non-zeros in this row.
|
||||||
|
#[inline]
|
||||||
|
pub fn nnz(&self) -> usize {
|
||||||
|
self.lane.nnz()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The column indices corresponding to explicitly stored entries in this row.
|
||||||
|
#[inline]
|
||||||
|
pub fn col_indices(&self) -> &[usize] {
|
||||||
|
self.lane.minor_indices()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The values corresponding to explicitly stored entries in this row.
|
||||||
|
#[inline]
|
||||||
|
pub fn values(&self) -> &[T] {
|
||||||
|
self.lane.values()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns an entry for the given global column index.
|
||||||
|
///
|
||||||
|
/// Each call to this function incurs the cost of a binary search among the explicitly
|
||||||
|
/// stored column entries.
|
||||||
|
#[inline]
|
||||||
|
pub fn get_entry(&self, global_col_index: usize) -> Option<SparseEntry<T>> {
|
||||||
|
self.lane.get_entry(global_col_index)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
impl_csr_row_common_methods!(CsrRow<'a, T>);
|
||||||
|
impl_csr_row_common_methods!(CsrRowMut<'a, T>);
|
||||||
|
|
||||||
|
impl<'a, T> CsrRowMut<'a, T> {
|
||||||
|
/// Mutable access to the values corresponding to explicitly stored entries in this row.
|
||||||
|
#[inline]
|
||||||
|
pub fn values_mut(&mut self) -> &mut [T] {
|
||||||
|
self.lane.values_mut()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Provides simultaneous access to column indices and mutable values corresponding to the
|
||||||
|
/// explicitly stored entries in this row.
|
||||||
|
///
|
||||||
|
/// This method primarily facilitates low-level access for methods that process data stored
|
||||||
|
/// in CSR format directly.
|
||||||
|
#[inline]
|
||||||
|
pub fn cols_and_values_mut(&mut self) -> (&[usize], &mut [T]) {
|
||||||
|
self.lane.indices_and_values_mut()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns a mutable entry for the given global column index.
|
||||||
|
#[inline]
|
||||||
|
pub fn get_entry_mut(&mut self, global_col_index: usize) -> Option<SparseEntryMut<T>> {
|
||||||
|
self.lane.get_entry_mut(global_col_index)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Row iterator for [CsrMatrix](struct.CsrMatrix.html).
|
||||||
|
pub struct CsrRowIter<'a, T> {
|
||||||
|
lane_iter: CsLaneIter<'a, T>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, T> Iterator for CsrRowIter<'a, T> {
|
||||||
|
type Item = CsrRow<'a, T>;
|
||||||
|
|
||||||
|
fn next(&mut self) -> Option<Self::Item> {
|
||||||
|
self.lane_iter.next().map(|lane| CsrRow { lane })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Mutable row iterator for [CsrMatrix](struct.CsrMatrix.html).
|
||||||
|
pub struct CsrRowIterMut<'a, T> {
|
||||||
|
lane_iter: CsLaneIterMut<'a, T>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, T> Iterator for CsrRowIterMut<'a, T>
|
||||||
|
where
|
||||||
|
T: 'a,
|
||||||
|
{
|
||||||
|
type Item = CsrRowMut<'a, T>;
|
||||||
|
|
||||||
|
fn next(&mut self) -> Option<Self::Item> {
|
||||||
|
self.lane_iter.next().map(|lane| CsrRowMut { lane })
|
||||||
|
}
|
||||||
|
}
|
373
nalgebra-sparse/src/factorization/cholesky.rs
Normal file
373
nalgebra-sparse/src/factorization/cholesky.rs
Normal file
@ -0,0 +1,373 @@
|
|||||||
|
use crate::csc::CscMatrix;
|
||||||
|
use crate::ops::serial::spsolve_csc_lower_triangular;
|
||||||
|
use crate::ops::Op;
|
||||||
|
use crate::pattern::SparsityPattern;
|
||||||
|
use core::{iter, mem};
|
||||||
|
use nalgebra::{DMatrix, DMatrixSlice, DMatrixSliceMut, RealField, Scalar};
|
||||||
|
use std::fmt::{Display, Formatter};
|
||||||
|
|
||||||
|
/// A symbolic sparse Cholesky factorization of a CSC matrix.
|
||||||
|
///
|
||||||
|
/// The symbolic factorization computes the sparsity pattern of `L`, the Cholesky factor.
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub struct CscSymbolicCholesky {
|
||||||
|
// Pattern of the original matrix that was decomposed
|
||||||
|
m_pattern: SparsityPattern,
|
||||||
|
l_pattern: SparsityPattern,
|
||||||
|
// u in this context is L^T, so that M = L L^T
|
||||||
|
u_pattern: SparsityPattern,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CscSymbolicCholesky {
|
||||||
|
/// Compute the symbolic factorization for a sparsity pattern belonging to a CSC matrix.
|
||||||
|
///
|
||||||
|
/// The sparsity pattern must be symmetric. However, this is not enforced, and it is the
|
||||||
|
/// responsibility of the user to ensure that this property holds.
|
||||||
|
///
|
||||||
|
/// # Panics
|
||||||
|
///
|
||||||
|
/// Panics if the sparsity pattern is not square.
|
||||||
|
pub fn factor(pattern: SparsityPattern) -> Self {
|
||||||
|
assert_eq!(
|
||||||
|
pattern.major_dim(),
|
||||||
|
pattern.minor_dim(),
|
||||||
|
"Major and minor dimensions must be the same (square matrix)."
|
||||||
|
);
|
||||||
|
let (l_pattern, u_pattern) = nonzero_pattern(&pattern);
|
||||||
|
Self {
|
||||||
|
m_pattern: pattern,
|
||||||
|
l_pattern,
|
||||||
|
u_pattern,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The pattern of the Cholesky factor `L`.
|
||||||
|
pub fn l_pattern(&self) -> &SparsityPattern {
|
||||||
|
&self.l_pattern
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A sparse Cholesky factorization `A = L L^T` of a [`CscMatrix`].
|
||||||
|
///
|
||||||
|
/// The factor `L` is a sparse, lower-triangular matrix. See the article on [Wikipedia] for
|
||||||
|
/// more information.
|
||||||
|
///
|
||||||
|
/// The implementation is a port of the `CsCholesky` implementation in `nalgebra`. It is similar
|
||||||
|
/// to Tim Davis' [`CSparse`]. The current implementation performs no fill-in reduction, and can
|
||||||
|
/// therefore be expected to produce much too dense Cholesky factors for many matrices.
|
||||||
|
/// It is therefore not currently recommended to use this implementation for serious projects.
|
||||||
|
///
|
||||||
|
/// [`CSparse`]: https://epubs.siam.org/doi/book/10.1137/1.9780898718881
|
||||||
|
/// [Wikipedia]: https://en.wikipedia.org/wiki/Cholesky_decomposition
|
||||||
|
// TODO: We should probably implement PartialEq/Eq, but in that case we'd probably need a
|
||||||
|
// custom implementation, due to the need to exclude the workspace arrays
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct CscCholesky<T> {
|
||||||
|
// Pattern of the original matrix
|
||||||
|
m_pattern: SparsityPattern,
|
||||||
|
l_factor: CscMatrix<T>,
|
||||||
|
u_pattern: SparsityPattern,
|
||||||
|
work_x: Vec<T>,
|
||||||
|
work_c: Vec<usize>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, PartialEq, Eq, Clone)]
|
||||||
|
#[non_exhaustive]
|
||||||
|
/// Possible errors produced by the Cholesky factorization.
|
||||||
|
pub enum CholeskyError {
|
||||||
|
/// The matrix is not positive definite.
|
||||||
|
NotPositiveDefinite,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Display for CholeskyError {
|
||||||
|
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
||||||
|
write!(f, "Matrix is not positive definite")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::error::Error for CholeskyError {}
|
||||||
|
|
||||||
|
impl<T: RealField> CscCholesky<T> {
|
||||||
|
/// Computes the numerical Cholesky factorization associated with the given
|
||||||
|
/// symbolic factorization and the provided values.
|
||||||
|
///
|
||||||
|
/// The values correspond to the non-zero values of the CSC matrix for which the
|
||||||
|
/// symbolic factorization was computed.
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns an error if the numerical factorization fails. This can occur if the matrix is not
|
||||||
|
/// symmetric positive definite.
|
||||||
|
///
|
||||||
|
/// # Panics
|
||||||
|
///
|
||||||
|
/// Panics if the number of values differ from the number of non-zeros of the sparsity pattern
|
||||||
|
/// of the matrix that was symbolically factored.
|
||||||
|
pub fn factor_numerical(
|
||||||
|
symbolic: CscSymbolicCholesky,
|
||||||
|
values: &[T],
|
||||||
|
) -> Result<Self, CholeskyError> {
|
||||||
|
assert_eq!(
|
||||||
|
symbolic.l_pattern.nnz(),
|
||||||
|
symbolic.u_pattern.nnz(),
|
||||||
|
"u is just the transpose of l, so should have the same nnz"
|
||||||
|
);
|
||||||
|
|
||||||
|
let l_nnz = symbolic.l_pattern.nnz();
|
||||||
|
let l_values = vec![T::zero(); l_nnz];
|
||||||
|
let l_factor =
|
||||||
|
CscMatrix::try_from_pattern_and_values(symbolic.l_pattern, l_values).unwrap();
|
||||||
|
|
||||||
|
let (nrows, ncols) = (l_factor.nrows(), l_factor.ncols());
|
||||||
|
|
||||||
|
let mut factorization = CscCholesky {
|
||||||
|
m_pattern: symbolic.m_pattern,
|
||||||
|
l_factor,
|
||||||
|
u_pattern: symbolic.u_pattern,
|
||||||
|
work_x: vec![T::zero(); nrows],
|
||||||
|
// Fill with MAX so that things hopefully totally fail if values are not
|
||||||
|
// overwritten. Might be easier to debug this way
|
||||||
|
work_c: vec![usize::MAX, ncols],
|
||||||
|
};
|
||||||
|
|
||||||
|
factorization.refactor(values)?;
|
||||||
|
Ok(factorization)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Computes the Cholesky factorization of the provided matrix.
|
||||||
|
///
|
||||||
|
/// The matrix must be symmetric positive definite. Symmetry is not checked, and it is up
|
||||||
|
/// to the user to enforce this property.
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns an error if the numerical factorization fails. This can occur if the matrix is not
|
||||||
|
/// symmetric positive definite.
|
||||||
|
///
|
||||||
|
/// # Panics
|
||||||
|
///
|
||||||
|
/// Panics if the matrix is not square.
|
||||||
|
pub fn factor(matrix: &CscMatrix<T>) -> Result<Self, CholeskyError> {
|
||||||
|
let symbolic = CscSymbolicCholesky::factor(matrix.pattern().clone());
|
||||||
|
Self::factor_numerical(symbolic, matrix.values())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Re-computes the factorization for a new set of non-zero values.
|
||||||
|
///
|
||||||
|
/// This is useful when the values of a matrix changes, but the sparsity pattern remains
|
||||||
|
/// constant.
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns an error if the numerical factorization fails. This can occur if the matrix is not
|
||||||
|
/// symmetric positive definite.
|
||||||
|
///
|
||||||
|
/// # Panics
|
||||||
|
///
|
||||||
|
/// Panics if the number of values does not match the number of non-zeros in the sparsity
|
||||||
|
/// pattern.
|
||||||
|
pub fn refactor(&mut self, values: &[T]) -> Result<(), CholeskyError> {
|
||||||
|
self.decompose_left_looking(values)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns a reference to the Cholesky factor `L`.
|
||||||
|
pub fn l(&self) -> &CscMatrix<T> {
|
||||||
|
&self.l_factor
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the Cholesky factor `L`.
|
||||||
|
pub fn take_l(self) -> CscMatrix<T> {
|
||||||
|
self.l_factor
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Perform a numerical left-looking cholesky decomposition of a matrix with the same structure as the
|
||||||
|
/// one used to initialize `self`, but with different non-zero values provided by `values`.
|
||||||
|
fn decompose_left_looking(&mut self, values: &[T]) -> Result<(), CholeskyError> {
|
||||||
|
assert!(
|
||||||
|
values.len() >= self.m_pattern.nnz(),
|
||||||
|
// TODO: Improve error message
|
||||||
|
"The set of values is too small."
|
||||||
|
);
|
||||||
|
|
||||||
|
let n = self.l_factor.nrows();
|
||||||
|
|
||||||
|
// Reset `work_c` to the column pointers of `l`.
|
||||||
|
self.work_c.clear();
|
||||||
|
self.work_c.extend_from_slice(self.l_factor.col_offsets());
|
||||||
|
|
||||||
|
unsafe {
|
||||||
|
for k in 0..n {
|
||||||
|
// Scatter the k-th column of the original matrix with the values provided.
|
||||||
|
let range_begin = *self.m_pattern.major_offsets().get_unchecked(k);
|
||||||
|
let range_end = *self.m_pattern.major_offsets().get_unchecked(k + 1);
|
||||||
|
let range_k = range_begin..range_end;
|
||||||
|
|
||||||
|
*self.work_x.get_unchecked_mut(k) = T::zero();
|
||||||
|
for p in range_k.clone() {
|
||||||
|
let irow = *self.m_pattern.minor_indices().get_unchecked(p);
|
||||||
|
|
||||||
|
if irow >= k {
|
||||||
|
*self.work_x.get_unchecked_mut(irow) = *values.get_unchecked(p);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for &j in self.u_pattern.lane(k) {
|
||||||
|
let factor = -*self
|
||||||
|
.l_factor
|
||||||
|
.values()
|
||||||
|
.get_unchecked(*self.work_c.get_unchecked(j));
|
||||||
|
*self.work_c.get_unchecked_mut(j) += 1;
|
||||||
|
|
||||||
|
if j < k {
|
||||||
|
let col_j = self.l_factor.col(j);
|
||||||
|
let col_j_entries = col_j.row_indices().iter().zip(col_j.values());
|
||||||
|
for (&z, val) in col_j_entries {
|
||||||
|
if z >= k {
|
||||||
|
*self.work_x.get_unchecked_mut(z) += val.inlined_clone() * factor;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let diag = *self.work_x.get_unchecked(k);
|
||||||
|
|
||||||
|
if diag > T::zero() {
|
||||||
|
let denom = diag.sqrt();
|
||||||
|
|
||||||
|
{
|
||||||
|
let (offsets, _, values) = self.l_factor.csc_data_mut();
|
||||||
|
*values.get_unchecked_mut(*offsets.get_unchecked(k)) = denom;
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut col_k = self.l_factor.col_mut(k);
|
||||||
|
let (col_k_rows, col_k_values) = col_k.rows_and_values_mut();
|
||||||
|
let col_k_entries = col_k_rows.iter().zip(col_k_values);
|
||||||
|
for (&p, val) in col_k_entries {
|
||||||
|
*val = *self.work_x.get_unchecked(p) / denom;
|
||||||
|
*self.work_x.get_unchecked_mut(p) = T::zero();
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return Err(CholeskyError::NotPositiveDefinite);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Solves the system `A X = B`, where `X` and `B` are dense matrices.
|
||||||
|
///
|
||||||
|
/// # Panics
|
||||||
|
///
|
||||||
|
/// Panics if `B` is not square.
|
||||||
|
pub fn solve<'a>(&'a self, b: impl Into<DMatrixSlice<'a, T>>) -> DMatrix<T> {
|
||||||
|
let b = b.into();
|
||||||
|
let mut output = b.clone_owned();
|
||||||
|
self.solve_mut(&mut output);
|
||||||
|
output
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Solves the system `AX = B`, where `X` and `B` are dense matrices.
|
||||||
|
///
|
||||||
|
/// The result is stored in-place in `b`.
|
||||||
|
///
|
||||||
|
/// # Panics
|
||||||
|
///
|
||||||
|
/// Panics if `b` is not square.
|
||||||
|
pub fn solve_mut<'a>(&'a self, b: impl Into<DMatrixSliceMut<'a, T>>) {
|
||||||
|
let expect_msg = "If the Cholesky factorization succeeded,\
|
||||||
|
then the triangular solve should never fail";
|
||||||
|
// Solve LY = B
|
||||||
|
let mut y = b.into();
|
||||||
|
spsolve_csc_lower_triangular(Op::NoOp(self.l()), &mut y).expect(expect_msg);
|
||||||
|
|
||||||
|
// Solve L^T X = Y
|
||||||
|
let mut x = y;
|
||||||
|
spsolve_csc_lower_triangular(Op::Transpose(self.l()), &mut x).expect(expect_msg);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn reach(
|
||||||
|
pattern: &SparsityPattern,
|
||||||
|
j: usize,
|
||||||
|
max_j: usize,
|
||||||
|
tree: &[usize],
|
||||||
|
marks: &mut Vec<bool>,
|
||||||
|
out: &mut Vec<usize>,
|
||||||
|
) {
|
||||||
|
marks.clear();
|
||||||
|
marks.resize(tree.len(), false);
|
||||||
|
|
||||||
|
// TODO: avoid all those allocations.
|
||||||
|
let mut tmp = Vec::new();
|
||||||
|
let mut res = Vec::new();
|
||||||
|
|
||||||
|
for &irow in pattern.lane(j) {
|
||||||
|
let mut curr = irow;
|
||||||
|
while curr != usize::max_value() && curr <= max_j && !marks[curr] {
|
||||||
|
marks[curr] = true;
|
||||||
|
tmp.push(curr);
|
||||||
|
curr = tree[curr];
|
||||||
|
}
|
||||||
|
|
||||||
|
tmp.append(&mut res);
|
||||||
|
mem::swap(&mut tmp, &mut res);
|
||||||
|
}
|
||||||
|
|
||||||
|
res.sort_unstable();
|
||||||
|
|
||||||
|
out.append(&mut res);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn nonzero_pattern(m: &SparsityPattern) -> (SparsityPattern, SparsityPattern) {
|
||||||
|
let etree = elimination_tree(m);
|
||||||
|
// Note: We assume CSC, therefore rows == minor and cols == major
|
||||||
|
let (nrows, ncols) = (m.minor_dim(), m.major_dim());
|
||||||
|
let mut rows = Vec::with_capacity(m.nnz());
|
||||||
|
let mut col_offsets = Vec::with_capacity(ncols + 1);
|
||||||
|
let mut marks = Vec::new();
|
||||||
|
|
||||||
|
// NOTE: the following will actually compute the non-zero pattern of
|
||||||
|
// the transpose of l.
|
||||||
|
col_offsets.push(0);
|
||||||
|
for i in 0..nrows {
|
||||||
|
reach(m, i, i, &etree, &mut marks, &mut rows);
|
||||||
|
col_offsets.push(rows.len());
|
||||||
|
}
|
||||||
|
|
||||||
|
let u_pattern =
|
||||||
|
SparsityPattern::try_from_offsets_and_indices(nrows, ncols, col_offsets, rows).unwrap();
|
||||||
|
|
||||||
|
// TODO: Avoid this transpose?
|
||||||
|
let l_pattern = u_pattern.transpose();
|
||||||
|
|
||||||
|
(l_pattern, u_pattern)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn elimination_tree(pattern: &SparsityPattern) -> Vec<usize> {
|
||||||
|
// Note: The pattern is assumed to of a CSC matrix, so the number of rows is
|
||||||
|
// given by the minor dimension
|
||||||
|
let nrows = pattern.minor_dim();
|
||||||
|
let mut forest: Vec<_> = iter::repeat(usize::max_value()).take(nrows).collect();
|
||||||
|
let mut ancestor: Vec<_> = iter::repeat(usize::max_value()).take(nrows).collect();
|
||||||
|
|
||||||
|
for k in 0..nrows {
|
||||||
|
for &irow in pattern.lane(k) {
|
||||||
|
let mut i = irow;
|
||||||
|
|
||||||
|
while i < k {
|
||||||
|
let i_ancestor = ancestor[i];
|
||||||
|
ancestor[i] = k;
|
||||||
|
|
||||||
|
if i_ancestor == usize::max_value() {
|
||||||
|
forest[i] = k;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
i = i_ancestor;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
forest
|
||||||
|
}
|
6
nalgebra-sparse/src/factorization/mod.rs
Normal file
6
nalgebra-sparse/src/factorization/mod.rs
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
//! Matrix factorization for sparse matrices.
|
||||||
|
//!
|
||||||
|
//! Currently, the only factorization provided here is the [`CscCholesky`] factorization.
|
||||||
|
mod cholesky;
|
||||||
|
|
||||||
|
pub use cholesky::*;
|
267
nalgebra-sparse/src/lib.rs
Normal file
267
nalgebra-sparse/src/lib.rs
Normal file
@ -0,0 +1,267 @@
|
|||||||
|
//! Sparse matrices and algorithms for [nalgebra](https://www.nalgebra.org).
|
||||||
|
//!
|
||||||
|
//! This crate extends `nalgebra` with sparse matrix formats and operations on sparse matrices.
|
||||||
|
//!
|
||||||
|
//! ## Goals
|
||||||
|
//! The long-term goals for this crate are listed below.
|
||||||
|
//!
|
||||||
|
//! - Provide proven sparse matrix formats in an easy-to-use and idiomatic Rust API that
|
||||||
|
//! naturally integrates with `nalgebra`.
|
||||||
|
//! - Provide additional expert-level APIs for fine-grained control over operations.
|
||||||
|
//! - Integrate well with external sparse matrix libraries.
|
||||||
|
//! - Provide native Rust high-performance routines, including parallel matrix operations.
|
||||||
|
//!
|
||||||
|
//! ## Highlighted current features
|
||||||
|
//!
|
||||||
|
//! - [CSR](csr::CsrMatrix), [CSC](csc::CscMatrix) and [COO](coo::CooMatrix) formats, and
|
||||||
|
//! [conversions](`convert`) between them.
|
||||||
|
//! - Common arithmetic operations are implemented. See the [`ops`] module.
|
||||||
|
//! - Sparsity patterns in CSR and CSC matrices are explicitly represented by the
|
||||||
|
//! [SparsityPattern](pattern::SparsityPattern) type, which encodes the invariants of the
|
||||||
|
//! associated index data structures.
|
||||||
|
//! - [proptest strategies](`proptest`) for sparse matrices when the feature
|
||||||
|
//! `proptest-support` is enabled.
|
||||||
|
//! - [matrixcompare support](https://crates.io/crates/matrixcompare) for effortless
|
||||||
|
//! (approximate) comparison of matrices in test code (requires the `compare` feature).
|
||||||
|
//!
|
||||||
|
//! ## Current state
|
||||||
|
//!
|
||||||
|
//! The library is in an early, but usable state. The API has been designed to be extensible,
|
||||||
|
//! but breaking changes will be necessary to implement several planned features. While it is
|
||||||
|
//! backed by an extensive test suite, it has yet to be thoroughly battle-tested in real
|
||||||
|
//! applications. Moreover, the focus so far has been on correctness and API design, with little
|
||||||
|
//! focus on performance. Future improvements will include incremental performance enhancements.
|
||||||
|
//!
|
||||||
|
//! Current limitations:
|
||||||
|
//!
|
||||||
|
//! - Limited or no availability of sparse system solvers.
|
||||||
|
//! - Limited support for complex numbers. Currently only arithmetic operations that do not
|
||||||
|
//! rely on particular properties of complex numbers, such as e.g. conjugation, are
|
||||||
|
//! supported.
|
||||||
|
//! - No integration with external libraries.
|
||||||
|
//!
|
||||||
|
//! # Usage
|
||||||
|
//!
|
||||||
|
//! Add the following to your `Cargo.toml` file:
|
||||||
|
//!
|
||||||
|
//! ```toml
|
||||||
|
//! [dependencies]
|
||||||
|
//! nalgebra_sparse = "0.1"
|
||||||
|
//! ```
|
||||||
|
//!
|
||||||
|
//! # Supported matrix formats
|
||||||
|
//!
|
||||||
|
//! | Format | Notes |
|
||||||
|
//! | ------------------------|--------------------------------------------- |
|
||||||
|
//! | [COO](`coo::CooMatrix`) | Well-suited for matrix construction. <br /> Ill-suited for algebraic operations. |
|
||||||
|
//! | [CSR](`csr::CsrMatrix`) | Immutable sparsity pattern, suitable for algebraic operations. <br /> Fast row access. |
|
||||||
|
//! | [CSC](`csc::CscMatrix`) | Immutable sparsity pattern, suitable for algebraic operations. <br /> Fast column access. |
|
||||||
|
//!
|
||||||
|
//! What format is best to use depends on the application. The most common use case for sparse
|
||||||
|
//! matrices in science is the solution of sparse linear systems. Here we can differentiate between
|
||||||
|
//! two common cases:
|
||||||
|
//!
|
||||||
|
//! - Direct solvers. Typically, direct solvers take their input in CSR or CSC format.
|
||||||
|
//! - Iterative solvers. Many iterative solvers require only matrix-vector products,
|
||||||
|
//! for which the CSR or CSC formats are suitable.
|
||||||
|
//!
|
||||||
|
//! The [COO](coo::CooMatrix) format is primarily intended for matrix construction.
|
||||||
|
//! A common pattern is to use COO for construction, before converting to CSR or CSC for use
|
||||||
|
//! in a direct solver or for computing matrix-vector products in an iterative solver.
|
||||||
|
//! Some high-performance applications might also directly manipulate the CSR and/or CSC
|
||||||
|
//! formats.
|
||||||
|
//!
|
||||||
|
//! # Example: COO -> CSR -> matrix-vector product
|
||||||
|
//!
|
||||||
|
//! ```rust
|
||||||
|
//! use nalgebra_sparse::{coo::CooMatrix, csr::CsrMatrix};
|
||||||
|
//! use nalgebra::{DMatrix, DVector};
|
||||||
|
//! use matrixcompare::assert_matrix_eq;
|
||||||
|
//!
|
||||||
|
//! // The dense representation of the matrix
|
||||||
|
//! let dense = DMatrix::from_row_slice(3, 3,
|
||||||
|
//! &[1.0, 0.0, 3.0,
|
||||||
|
//! 2.0, 0.0, 1.3,
|
||||||
|
//! 0.0, 0.0, 4.1]);
|
||||||
|
//!
|
||||||
|
//! // Build the equivalent COO representation. We only add the non-zero values
|
||||||
|
//! let mut coo = CooMatrix::new(3, 3);
|
||||||
|
//! // We can add elements in any order. For clarity, we do so in row-major order here.
|
||||||
|
//! coo.push(0, 0, 1.0);
|
||||||
|
//! coo.push(0, 2, 3.0);
|
||||||
|
//! coo.push(1, 0, 2.0);
|
||||||
|
//! coo.push(1, 2, 1.3);
|
||||||
|
//! coo.push(2, 2, 4.1);
|
||||||
|
//!
|
||||||
|
//! // The simplest way to construct a CSR matrix is to first construct a COO matrix, and
|
||||||
|
//! // then convert it to CSR. The `From` trait is implemented for conversions between different
|
||||||
|
//! // sparse matrix types.
|
||||||
|
//! // Alternatively, we can construct a matrix directly from the CSR data.
|
||||||
|
//! // See the docs for CsrMatrix for how to do that.
|
||||||
|
//! let csr = CsrMatrix::from(&coo);
|
||||||
|
//!
|
||||||
|
//! // Let's check that the CSR matrix and the dense matrix represent the same matrix.
|
||||||
|
//! // We can use macros from the `matrixcompare` crate to easily do this, despite the fact that
|
||||||
|
//! // we're comparing across two different matrix formats. Note that these macros are only really
|
||||||
|
//! // appropriate for writing tests, however.
|
||||||
|
//! assert_matrix_eq!(csr, dense);
|
||||||
|
//!
|
||||||
|
//! let x = DVector::from_column_slice(&[1.3, -4.0, 3.5]);
|
||||||
|
//!
|
||||||
|
//! // Compute the matrix-vector product y = A * x. We don't need to specify the type here,
|
||||||
|
//! // but let's just do it to make sure we get what we expect
|
||||||
|
//! let y: DVector<_> = &csr * &x;
|
||||||
|
//!
|
||||||
|
//! // Verify the result with a small element-wise absolute tolerance
|
||||||
|
//! let y_expected = DVector::from_column_slice(&[11.8, 7.15, 14.35]);
|
||||||
|
//! assert_matrix_eq!(y, y_expected, comp = abs, tol = 1e-9);
|
||||||
|
//!
|
||||||
|
//! // The above expression is simple, and gives easy to read code, but if we're doing this in a
|
||||||
|
//! // loop, we'll have to keep allocating new vectors. If we determine that this is a bottleneck,
|
||||||
|
//! // then we can resort to the lower level APIs for more control over the operations
|
||||||
|
//! {
|
||||||
|
//! use nalgebra_sparse::ops::{Op, serial::spmm_csr_dense};
|
||||||
|
//! let mut y = y;
|
||||||
|
//! // Compute y <- 0.0 * y + 1.0 * csr * dense. We store the result directly in `y`, without
|
||||||
|
//! // any intermediate allocations
|
||||||
|
//! spmm_csr_dense(0.0, &mut y, 1.0, Op::NoOp(&csr), Op::NoOp(&x));
|
||||||
|
//! assert_matrix_eq!(y, y_expected, comp = abs, tol = 1e-9);
|
||||||
|
//! }
|
||||||
|
//! ```
|
||||||
|
#![deny(non_camel_case_types)]
|
||||||
|
#![deny(unused_parens)]
|
||||||
|
#![deny(non_upper_case_globals)]
|
||||||
|
#![deny(unused_qualifications)]
|
||||||
|
#![deny(unused_results)]
|
||||||
|
#![deny(missing_docs)]
|
||||||
|
|
||||||
|
pub extern crate nalgebra as na;
|
||||||
|
pub mod convert;
|
||||||
|
pub mod coo;
|
||||||
|
pub mod csc;
|
||||||
|
pub mod csr;
|
||||||
|
pub mod factorization;
|
||||||
|
pub mod ops;
|
||||||
|
pub mod pattern;
|
||||||
|
|
||||||
|
pub(crate) mod cs;
|
||||||
|
|
||||||
|
#[cfg(feature = "proptest-support")]
|
||||||
|
pub mod proptest;
|
||||||
|
|
||||||
|
#[cfg(feature = "compare")]
|
||||||
|
mod matrixcompare;
|
||||||
|
|
||||||
|
use num_traits::Zero;
|
||||||
|
use std::error::Error;
|
||||||
|
use std::fmt;
|
||||||
|
|
||||||
|
pub use self::coo::CooMatrix;
|
||||||
|
pub use self::csc::CscMatrix;
|
||||||
|
pub use self::csr::CsrMatrix;
|
||||||
|
|
||||||
|
/// Errors produced by functions that expect well-formed sparse format data.
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct SparseFormatError {
|
||||||
|
kind: SparseFormatErrorKind,
|
||||||
|
// Currently we only use an underlying error for generating the `Display` impl
|
||||||
|
error: Box<dyn Error>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl SparseFormatError {
|
||||||
|
/// The type of error.
|
||||||
|
pub fn kind(&self) -> &SparseFormatErrorKind {
|
||||||
|
&self.kind
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn from_kind_and_error(kind: SparseFormatErrorKind, error: Box<dyn Error>) -> Self {
|
||||||
|
Self { kind, error }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Helper functionality for more conveniently creating errors.
|
||||||
|
pub(crate) fn from_kind_and_msg(kind: SparseFormatErrorKind, msg: &'static str) -> Self {
|
||||||
|
Self::from_kind_and_error(kind, Box::<dyn Error>::from(msg))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The type of format error described by a [SparseFormatError](struct.SparseFormatError.html).
|
||||||
|
#[non_exhaustive]
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub enum SparseFormatErrorKind {
|
||||||
|
/// Indicates that the index data associated with the format contains at least one index
|
||||||
|
/// out of bounds.
|
||||||
|
IndexOutOfBounds,
|
||||||
|
|
||||||
|
/// Indicates that the provided data contains at least one duplicate entry, and the
|
||||||
|
/// current format does not support duplicate entries.
|
||||||
|
DuplicateEntry,
|
||||||
|
|
||||||
|
/// Indicates that the provided data for the format does not conform to the high-level
|
||||||
|
/// structure of the format.
|
||||||
|
///
|
||||||
|
/// For example, the arrays defining the format data might have incompatible sizes.
|
||||||
|
InvalidStructure,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl fmt::Display for SparseFormatError {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||||
|
write!(f, "{}", self.error)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Error for SparseFormatError {}
|
||||||
|
|
||||||
|
/// An entry in a sparse matrix.
|
||||||
|
///
|
||||||
|
/// Sparse matrices do not store all their entries explicitly. Therefore, entry (i, j) in the matrix
|
||||||
|
/// can either be a reference to an explicitly stored element, or it is implicitly zero.
|
||||||
|
#[derive(Debug, PartialEq, Eq)]
|
||||||
|
pub enum SparseEntry<'a, T> {
|
||||||
|
/// The entry is a reference to an explicitly stored element.
|
||||||
|
///
|
||||||
|
/// Note that the naming here is a misnomer: The element can still be zero, even though it
|
||||||
|
/// is explicitly stored (a so-called "explicit zero").
|
||||||
|
NonZero(&'a T),
|
||||||
|
/// The entry is implicitly zero, i.e. it is not explicitly stored.
|
||||||
|
Zero,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, T: Clone + Zero> SparseEntry<'a, T> {
|
||||||
|
/// Returns the value represented by this entry.
|
||||||
|
///
|
||||||
|
/// Either clones the underlying reference or returns zero if the entry is not explicitly
|
||||||
|
/// stored.
|
||||||
|
pub fn into_value(self) -> T {
|
||||||
|
match self {
|
||||||
|
SparseEntry::NonZero(value) => value.clone(),
|
||||||
|
SparseEntry::Zero => T::zero(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A mutable entry in a sparse matrix.
|
||||||
|
///
|
||||||
|
/// See also `SparseEntry`.
|
||||||
|
#[derive(Debug, PartialEq, Eq)]
|
||||||
|
pub enum SparseEntryMut<'a, T> {
|
||||||
|
/// The entry is a mutable reference to an explicitly stored element.
|
||||||
|
///
|
||||||
|
/// Note that the naming here is a misnomer: The element can still be zero, even though it
|
||||||
|
/// is explicitly stored (a so-called "explicit zero").
|
||||||
|
NonZero(&'a mut T),
|
||||||
|
/// The entry is implicitly zero i.e. it is not explicitly stored.
|
||||||
|
Zero,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, T: Clone + Zero> SparseEntryMut<'a, T> {
|
||||||
|
/// Returns the value represented by this entry.
|
||||||
|
///
|
||||||
|
/// Either clones the underlying reference or returns zero if the entry is not explicitly
|
||||||
|
/// stored.
|
||||||
|
pub fn into_value(self) -> T {
|
||||||
|
match self {
|
||||||
|
SparseEntryMut::NonZero(value) => value.clone(),
|
||||||
|
SparseEntryMut::Zero => T::zero(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
65
nalgebra-sparse/src/matrixcompare.rs
Normal file
65
nalgebra-sparse/src/matrixcompare.rs
Normal file
@ -0,0 +1,65 @@
|
|||||||
|
//! Implements core traits for use with `matrixcompare`.
|
||||||
|
use crate::coo::CooMatrix;
|
||||||
|
use crate::csc::CscMatrix;
|
||||||
|
use crate::csr::CsrMatrix;
|
||||||
|
use matrixcompare_core;
|
||||||
|
use matrixcompare_core::{Access, SparseAccess};
|
||||||
|
|
||||||
|
macro_rules! impl_matrix_for_csr_csc {
|
||||||
|
($MatrixType:ident) => {
|
||||||
|
impl<T: Clone> SparseAccess<T> for $MatrixType<T> {
|
||||||
|
fn nnz(&self) -> usize {
|
||||||
|
$MatrixType::nnz(self)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn fetch_triplets(&self) -> Vec<(usize, usize, T)> {
|
||||||
|
self.triplet_iter()
|
||||||
|
.map(|(i, j, v)| (i, j, v.clone()))
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: Clone> matrixcompare_core::Matrix<T> for $MatrixType<T> {
|
||||||
|
fn rows(&self) -> usize {
|
||||||
|
self.nrows()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn cols(&self) -> usize {
|
||||||
|
self.ncols()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn access(&self) -> Access<T> {
|
||||||
|
Access::Sparse(self)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
impl_matrix_for_csr_csc!(CsrMatrix);
|
||||||
|
impl_matrix_for_csr_csc!(CscMatrix);
|
||||||
|
|
||||||
|
impl<T: Clone> SparseAccess<T> for CooMatrix<T> {
|
||||||
|
fn nnz(&self) -> usize {
|
||||||
|
CooMatrix::nnz(self)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn fetch_triplets(&self) -> Vec<(usize, usize, T)> {
|
||||||
|
self.triplet_iter()
|
||||||
|
.map(|(i, j, v)| (i, j, v.clone()))
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: Clone> matrixcompare_core::Matrix<T> for CooMatrix<T> {
|
||||||
|
fn rows(&self) -> usize {
|
||||||
|
self.nrows()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn cols(&self) -> usize {
|
||||||
|
self.ncols()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn access(&self) -> Access<T> {
|
||||||
|
Access::Sparse(self)
|
||||||
|
}
|
||||||
|
}
|
331
nalgebra-sparse/src/ops/impl_std_ops.rs
Normal file
331
nalgebra-sparse/src/ops/impl_std_ops.rs
Normal file
@ -0,0 +1,331 @@
|
|||||||
|
use crate::csc::CscMatrix;
|
||||||
|
use crate::csr::CsrMatrix;
|
||||||
|
|
||||||
|
use crate::ops::serial::{
|
||||||
|
spadd_csc_prealloc, spadd_csr_prealloc, spadd_pattern, spmm_csc_dense, spmm_csc_pattern,
|
||||||
|
spmm_csc_prealloc, spmm_csr_dense, spmm_csr_pattern, spmm_csr_prealloc,
|
||||||
|
};
|
||||||
|
use crate::ops::Op;
|
||||||
|
use nalgebra::allocator::Allocator;
|
||||||
|
use nalgebra::base::storage::Storage;
|
||||||
|
use nalgebra::constraint::{DimEq, ShapeConstraint};
|
||||||
|
use nalgebra::{
|
||||||
|
ClosedAdd, ClosedDiv, ClosedMul, ClosedSub, DefaultAllocator, Dim, Dynamic, Matrix, MatrixMN,
|
||||||
|
Scalar, U1,
|
||||||
|
};
|
||||||
|
use num_traits::{One, Zero};
|
||||||
|
use std::ops::{Add, Div, DivAssign, Mul, MulAssign, Neg, Sub};
|
||||||
|
|
||||||
|
/// Helper macro for implementing binary operators for different matrix types
|
||||||
|
/// See below for usage.
|
||||||
|
macro_rules! impl_bin_op {
|
||||||
|
($trait:ident, $method:ident,
|
||||||
|
<$($life:lifetime),* $(,)? $($scalar_type:ident $(: $bounds:path)?)?>($a:ident : $a_type:ty, $b:ident : $b_type:ty) -> $ret:ty $body:block)
|
||||||
|
=>
|
||||||
|
{
|
||||||
|
impl<$($life,)* $($scalar_type)?> $trait<$b_type> for $a_type
|
||||||
|
where
|
||||||
|
// Note: The Neg bound is currently required because we delegate e.g.
|
||||||
|
// Sub to SpAdd with negative coefficients. This is not well-defined for
|
||||||
|
// unsigned data types.
|
||||||
|
$($scalar_type: $($bounds + )? Scalar + ClosedAdd + ClosedSub + ClosedMul + Zero + One + Neg<Output=T>)?
|
||||||
|
{
|
||||||
|
type Output = $ret;
|
||||||
|
fn $method(self, $b: $b_type) -> Self::Output {
|
||||||
|
let $a = self;
|
||||||
|
$body
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Implements a +/- b for all combinations of reference and owned matrices, for
|
||||||
|
/// CsrMatrix or CscMatrix.
|
||||||
|
macro_rules! impl_sp_plus_minus {
|
||||||
|
// We first match on some special-case syntax, and forward to the actual implementation
|
||||||
|
($matrix_type:ident, $spadd_fn:ident, +) => {
|
||||||
|
impl_sp_plus_minus!(Add, add, $matrix_type, $spadd_fn, +, T::one());
|
||||||
|
};
|
||||||
|
($matrix_type:ident, $spadd_fn:ident, -) => {
|
||||||
|
impl_sp_plus_minus!(Sub, sub, $matrix_type, $spadd_fn, -, -T::one());
|
||||||
|
};
|
||||||
|
($trait:ident, $method:ident, $matrix_type:ident, $spadd_fn:ident, $sign:tt, $factor:expr) => {
|
||||||
|
impl_bin_op!($trait, $method,
|
||||||
|
<'a, T>(a: &'a $matrix_type<T>, b: &'a $matrix_type<T>) -> $matrix_type<T> {
|
||||||
|
// If both matrices have the same pattern, then we can immediately re-use it
|
||||||
|
let pattern = spadd_pattern(a.pattern(), b.pattern());
|
||||||
|
let values = vec![T::zero(); pattern.nnz()];
|
||||||
|
// We are giving data that is valid by definition, so it is safe to unwrap below
|
||||||
|
let mut result = $matrix_type::try_from_pattern_and_values(pattern, values)
|
||||||
|
.unwrap();
|
||||||
|
$spadd_fn(T::zero(), &mut result, T::one(), Op::NoOp(&a)).unwrap();
|
||||||
|
$spadd_fn(T::one(), &mut result, $factor * T::one(), Op::NoOp(&b)).unwrap();
|
||||||
|
result
|
||||||
|
});
|
||||||
|
|
||||||
|
impl_bin_op!($trait, $method,
|
||||||
|
<'a, T>(a: $matrix_type<T>, b: &'a $matrix_type<T>) -> $matrix_type<T> {
|
||||||
|
&a $sign b
|
||||||
|
});
|
||||||
|
|
||||||
|
impl_bin_op!($trait, $method,
|
||||||
|
<'a, T>(a: &'a $matrix_type<T>, b: $matrix_type<T>) -> $matrix_type<T> {
|
||||||
|
a $sign &b
|
||||||
|
});
|
||||||
|
impl_bin_op!($trait, $method, <T>(a: $matrix_type<T>, b: $matrix_type<T>) -> $matrix_type<T> {
|
||||||
|
a $sign &b
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl_sp_plus_minus!(CsrMatrix, spadd_csr_prealloc, +);
|
||||||
|
impl_sp_plus_minus!(CsrMatrix, spadd_csr_prealloc, -);
|
||||||
|
impl_sp_plus_minus!(CscMatrix, spadd_csc_prealloc, +);
|
||||||
|
impl_sp_plus_minus!(CscMatrix, spadd_csc_prealloc, -);
|
||||||
|
|
||||||
|
macro_rules! impl_mul {
|
||||||
|
($($args:tt)*) => {
|
||||||
|
impl_bin_op!(Mul, mul, $($args)*);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Implements a + b for all combinations of reference and owned matrices, for
|
||||||
|
/// CsrMatrix or CscMatrix.
|
||||||
|
macro_rules! impl_spmm {
|
||||||
|
($matrix_type:ident, $pattern_fn:expr, $spmm_fn:expr) => {
|
||||||
|
impl_mul!(<'a, T>(a: &'a $matrix_type<T>, b: &'a $matrix_type<T>) -> $matrix_type<T> {
|
||||||
|
let pattern = $pattern_fn(a.pattern(), b.pattern());
|
||||||
|
let values = vec![T::zero(); pattern.nnz()];
|
||||||
|
let mut result = $matrix_type::try_from_pattern_and_values(pattern, values)
|
||||||
|
.unwrap();
|
||||||
|
$spmm_fn(T::zero(),
|
||||||
|
&mut result,
|
||||||
|
T::one(),
|
||||||
|
Op::NoOp(a),
|
||||||
|
Op::NoOp(b))
|
||||||
|
.expect("Internal error: spmm failed (please debug).");
|
||||||
|
result
|
||||||
|
});
|
||||||
|
impl_mul!(<'a, T>(a: &'a $matrix_type<T>, b: $matrix_type<T>) -> $matrix_type<T> { a * &b});
|
||||||
|
impl_mul!(<'a, T>(a: $matrix_type<T>, b: &'a $matrix_type<T>) -> $matrix_type<T> { &a * b});
|
||||||
|
impl_mul!(<T>(a: $matrix_type<T>, b: $matrix_type<T>) -> $matrix_type<T> { &a * &b});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl_spmm!(CsrMatrix, spmm_csr_pattern, spmm_csr_prealloc);
|
||||||
|
// Need to switch order of operations for CSC pattern
|
||||||
|
impl_spmm!(CscMatrix, spmm_csc_pattern, spmm_csc_prealloc);
|
||||||
|
|
||||||
|
/// Implements Scalar * Matrix operations for *concrete* scalar types. The reason this is necessary
|
||||||
|
/// is that we are not able to implement Mul<Matrix<T>> for all T generically due to orphan rules.
|
||||||
|
macro_rules! impl_concrete_scalar_matrix_mul {
|
||||||
|
($matrix_type:ident, $($scalar_type:ty),*) => {
|
||||||
|
// For each concrete scalar type, forward the implementation of scalar * matrix
|
||||||
|
// to matrix * scalar, which we have already implemented through generics
|
||||||
|
$(
|
||||||
|
impl_mul!(<>(a: $scalar_type, b: $matrix_type<$scalar_type>)
|
||||||
|
-> $matrix_type<$scalar_type> { b * a });
|
||||||
|
impl_mul!(<'a>(a: $scalar_type, b: &'a $matrix_type<$scalar_type>)
|
||||||
|
-> $matrix_type<$scalar_type> { b * a });
|
||||||
|
impl_mul!(<'a>(a: &'a $scalar_type, b: $matrix_type<$scalar_type>)
|
||||||
|
-> $matrix_type<$scalar_type> { b * (*a) });
|
||||||
|
impl_mul!(<'a>(a: &'a $scalar_type, b: &'a $matrix_type<$scalar_type>)
|
||||||
|
-> $matrix_type<$scalar_type> { b * *a });
|
||||||
|
)*
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Implements multiplication between matrix and scalar for various matrix types
|
||||||
|
macro_rules! impl_scalar_mul {
|
||||||
|
($matrix_type: ident) => {
|
||||||
|
impl_mul!(<'a, T>(a: &'a $matrix_type<T>, b: &'a T) -> $matrix_type<T> {
|
||||||
|
let values: Vec<_> = a.values()
|
||||||
|
.iter()
|
||||||
|
.map(|v_i| v_i.inlined_clone() * b.inlined_clone())
|
||||||
|
.collect();
|
||||||
|
$matrix_type::try_from_pattern_and_values(a.pattern().clone(), values).unwrap()
|
||||||
|
});
|
||||||
|
impl_mul!(<'a, T>(a: &'a $matrix_type<T>, b: T) -> $matrix_type<T> {
|
||||||
|
a * &b
|
||||||
|
});
|
||||||
|
impl_mul!(<'a, T>(a: $matrix_type<T>, b: &'a T) -> $matrix_type<T> {
|
||||||
|
let mut a = a;
|
||||||
|
for value in a.values_mut() {
|
||||||
|
*value = b.inlined_clone() * value.inlined_clone();
|
||||||
|
}
|
||||||
|
a
|
||||||
|
});
|
||||||
|
impl_mul!(<T>(a: $matrix_type<T>, b: T) -> $matrix_type<T> {
|
||||||
|
a * &b
|
||||||
|
});
|
||||||
|
impl_concrete_scalar_matrix_mul!(
|
||||||
|
$matrix_type,
|
||||||
|
i8, i16, i32, i64, isize, f32, f64);
|
||||||
|
|
||||||
|
impl<T> MulAssign<T> for $matrix_type<T>
|
||||||
|
where
|
||||||
|
T: Scalar + ClosedAdd + ClosedMul + Zero + One
|
||||||
|
{
|
||||||
|
fn mul_assign(&mut self, scalar: T) {
|
||||||
|
for val in self.values_mut() {
|
||||||
|
*val *= scalar.inlined_clone();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, T> MulAssign<&'a T> for $matrix_type<T>
|
||||||
|
where
|
||||||
|
T: Scalar + ClosedAdd + ClosedMul + Zero + One
|
||||||
|
{
|
||||||
|
fn mul_assign(&mut self, scalar: &'a T) {
|
||||||
|
for val in self.values_mut() {
|
||||||
|
*val *= scalar.inlined_clone();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl_scalar_mul!(CsrMatrix);
|
||||||
|
impl_scalar_mul!(CscMatrix);
|
||||||
|
|
||||||
|
macro_rules! impl_neg {
|
||||||
|
($matrix_type:ident) => {
|
||||||
|
impl<T> Neg for $matrix_type<T>
|
||||||
|
where
|
||||||
|
T: Scalar + Neg<Output = T>,
|
||||||
|
{
|
||||||
|
type Output = $matrix_type<T>;
|
||||||
|
|
||||||
|
fn neg(mut self) -> Self::Output {
|
||||||
|
for v_i in self.values_mut() {
|
||||||
|
*v_i = -v_i.inlined_clone();
|
||||||
|
}
|
||||||
|
self
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, T> Neg for &'a $matrix_type<T>
|
||||||
|
where
|
||||||
|
T: Scalar + Neg<Output = T>,
|
||||||
|
{
|
||||||
|
type Output = $matrix_type<T>;
|
||||||
|
|
||||||
|
fn neg(self) -> Self::Output {
|
||||||
|
// TODO: This is inefficient. Ideally we'd have a method that would let us
|
||||||
|
// obtain both the sparsity pattern and values from the matrix,
|
||||||
|
// and then modify the values before creating a new matrix from the pattern
|
||||||
|
// and negated values.
|
||||||
|
-self.clone()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
impl_neg!(CsrMatrix);
|
||||||
|
impl_neg!(CscMatrix);
|
||||||
|
|
||||||
|
macro_rules! impl_div {
|
||||||
|
($matrix_type:ident) => {
|
||||||
|
impl_bin_op!(Div, div, <T: ClosedDiv>(matrix: $matrix_type<T>, scalar: T) -> $matrix_type<T> {
|
||||||
|
let mut matrix = matrix;
|
||||||
|
matrix /= scalar;
|
||||||
|
matrix
|
||||||
|
});
|
||||||
|
impl_bin_op!(Div, div, <'a, T: ClosedDiv>(matrix: $matrix_type<T>, scalar: &T) -> $matrix_type<T> {
|
||||||
|
matrix / scalar.inlined_clone()
|
||||||
|
});
|
||||||
|
impl_bin_op!(Div, div, <'a, T: ClosedDiv>(matrix: &'a $matrix_type<T>, scalar: T) -> $matrix_type<T> {
|
||||||
|
let new_values = matrix.values()
|
||||||
|
.iter()
|
||||||
|
.map(|v_i| v_i.inlined_clone() / scalar.inlined_clone())
|
||||||
|
.collect();
|
||||||
|
$matrix_type::try_from_pattern_and_values(matrix.pattern().clone(), new_values)
|
||||||
|
.unwrap()
|
||||||
|
});
|
||||||
|
impl_bin_op!(Div, div, <'a, T: ClosedDiv>(matrix: &'a $matrix_type<T>, scalar: &'a T) -> $matrix_type<T> {
|
||||||
|
matrix / scalar.inlined_clone()
|
||||||
|
});
|
||||||
|
|
||||||
|
impl<T> DivAssign<T> for $matrix_type<T>
|
||||||
|
where T : Scalar + ClosedAdd + ClosedMul + ClosedDiv + Zero + One
|
||||||
|
{
|
||||||
|
fn div_assign(&mut self, scalar: T) {
|
||||||
|
self.values_mut().iter_mut().for_each(|v_i| *v_i /= scalar.inlined_clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, T> DivAssign<&'a T> for $matrix_type<T>
|
||||||
|
where T : Scalar + ClosedAdd + ClosedMul + ClosedDiv + Zero + One
|
||||||
|
{
|
||||||
|
fn div_assign(&mut self, scalar: &'a T) {
|
||||||
|
*self /= scalar.inlined_clone();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl_div!(CsrMatrix);
|
||||||
|
impl_div!(CscMatrix);
|
||||||
|
|
||||||
|
macro_rules! impl_spmm_cs_dense {
|
||||||
|
($matrix_type_name:ident, $spmm_fn:ident) => {
|
||||||
|
// Implement ref-ref
|
||||||
|
impl_spmm_cs_dense!(&'a $matrix_type_name<T>, &'a Matrix<T, R, C, S>, $spmm_fn, |lhs, rhs| {
|
||||||
|
let (_, ncols) = rhs.data.shape();
|
||||||
|
let nrows = Dynamic::new(lhs.nrows());
|
||||||
|
let mut result = MatrixMN::<T, Dynamic, C>::zeros_generic(nrows, ncols);
|
||||||
|
$spmm_fn(T::zero(), &mut result, T::one(), Op::NoOp(lhs), Op::NoOp(rhs));
|
||||||
|
result
|
||||||
|
});
|
||||||
|
|
||||||
|
// Implement the other combinations by deferring to ref-ref
|
||||||
|
impl_spmm_cs_dense!(&'a $matrix_type_name<T>, Matrix<T, R, C, S>, $spmm_fn, |lhs, rhs| {
|
||||||
|
lhs * &rhs
|
||||||
|
});
|
||||||
|
impl_spmm_cs_dense!($matrix_type_name<T>, &'a Matrix<T, R, C, S>, $spmm_fn, |lhs, rhs| {
|
||||||
|
&lhs * rhs
|
||||||
|
});
|
||||||
|
impl_spmm_cs_dense!($matrix_type_name<T>, Matrix<T, R, C, S>, $spmm_fn, |lhs, rhs| {
|
||||||
|
&lhs * &rhs
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
|
// Main body of the macro. The first pattern just forwards to this pattern but with
|
||||||
|
// different arguments
|
||||||
|
($sparse_matrix_type:ty, $dense_matrix_type:ty, $spmm_fn:ident,
|
||||||
|
|$lhs:ident, $rhs:ident| $body:tt) =>
|
||||||
|
{
|
||||||
|
impl<'a, T, R, C, S> Mul<$dense_matrix_type> for $sparse_matrix_type
|
||||||
|
where
|
||||||
|
T: Scalar + ClosedMul + ClosedAdd + ClosedSub + ClosedDiv + Neg + Zero + One,
|
||||||
|
R: Dim,
|
||||||
|
C: Dim,
|
||||||
|
S: Storage<T, R, C>,
|
||||||
|
DefaultAllocator: Allocator<T, Dynamic, C>,
|
||||||
|
// TODO: Is it possible to simplify these bounds?
|
||||||
|
ShapeConstraint:
|
||||||
|
// Bounds so that we can turn MatrixMN<T, Dynamic, C> into a DMatrixSliceMut
|
||||||
|
DimEq<U1, <<DefaultAllocator as Allocator<T, Dynamic, C>>::Buffer as Storage<T, Dynamic, C>>::RStride>
|
||||||
|
+ DimEq<C, Dynamic>
|
||||||
|
+ DimEq<Dynamic, <<DefaultAllocator as Allocator<T, Dynamic, C>>::Buffer as Storage<T, Dynamic, C>>::CStride>
|
||||||
|
// Bounds so that we can turn &Matrix<T, R, C, S> into a DMatrixSlice
|
||||||
|
+ DimEq<U1, S::RStride>
|
||||||
|
+ DimEq<R, Dynamic>
|
||||||
|
+ DimEq<Dynamic, S::CStride>
|
||||||
|
{
|
||||||
|
// We need the column dimension to be generic, so that if RHS is a vector, then
|
||||||
|
// we also get a vector (and not a matrix)
|
||||||
|
type Output = MatrixMN<T, Dynamic, C>;
|
||||||
|
|
||||||
|
fn mul(self, rhs: $dense_matrix_type) -> Self::Output {
|
||||||
|
let $lhs = self;
|
||||||
|
let $rhs = rhs;
|
||||||
|
$body
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl_spmm_cs_dense!(CsrMatrix, spmm_csr_dense);
|
||||||
|
impl_spmm_cs_dense!(CscMatrix, spmm_csc_dense);
|
194
nalgebra-sparse/src/ops/mod.rs
Normal file
194
nalgebra-sparse/src/ops/mod.rs
Normal file
@ -0,0 +1,194 @@
|
|||||||
|
//! Sparse matrix arithmetic operations.
|
||||||
|
//!
|
||||||
|
//! This module contains a number of routines for sparse matrix arithmetic. These routines are
|
||||||
|
//! primarily intended for "expert usage". Most users should prefer to use standard
|
||||||
|
//! `std::ops` operations for simple and readable code when possible. The routines provided here
|
||||||
|
//! offer more control over allocation, and allow fusing some low-level operations for higher
|
||||||
|
//! performance.
|
||||||
|
//!
|
||||||
|
//! The available operations are organized by backend. Currently, only the [`serial`] backend
|
||||||
|
//! is available. In the future, backends that expose parallel operations may become available.
|
||||||
|
//! All `std::ops` implementations will remain single-threaded and powered by the
|
||||||
|
//! `serial` backend.
|
||||||
|
//!
|
||||||
|
//! Many routines are able to implicitly transpose matrices involved in the operation.
|
||||||
|
//! For example, the routine [`spadd_csr_prealloc`](serial::spadd_csr_prealloc) performs the
|
||||||
|
//! operation `C <- beta * C + alpha * op(A)`. Here `op(A)` indicates that the matrix `A` can
|
||||||
|
//! either be used as-is or transposed. The notation `op(A)` is represented in code by the
|
||||||
|
//! [`Op`] enum.
|
||||||
|
//!
|
||||||
|
//! # Available `std::ops` implementations
|
||||||
|
//!
|
||||||
|
//! ## Binary operators
|
||||||
|
//!
|
||||||
|
//! The below table summarizes the currently supported binary operators between matrices.
|
||||||
|
//! In general, binary operators between sparse matrices are only supported if both matrices
|
||||||
|
//! are stored in the same format. All supported binary operators are implemented for
|
||||||
|
//! all four combinations of values and references.
|
||||||
|
//!
|
||||||
|
//! <table>
|
||||||
|
//! <tr>
|
||||||
|
//! <th>LHS (down) \ RHS (right)</th>
|
||||||
|
//! <th>COO</th>
|
||||||
|
//! <th>CSR</th>
|
||||||
|
//! <th>CSC</th>
|
||||||
|
//! <th>Dense</th>
|
||||||
|
//! </tr>
|
||||||
|
//! <tr>
|
||||||
|
//! <th>COO</th>
|
||||||
|
//! <td></td>
|
||||||
|
//! <td></td>
|
||||||
|
//! <td></td>
|
||||||
|
//! <td></td>
|
||||||
|
//! </tr>
|
||||||
|
//! <tr>
|
||||||
|
//! <th>CSR</th>
|
||||||
|
//! <td></td>
|
||||||
|
//! <td>+ - *</td>
|
||||||
|
//! <td></td>
|
||||||
|
//! <td>*</td>
|
||||||
|
//! </tr>
|
||||||
|
//! <tr>
|
||||||
|
//! <th>CSC</th>
|
||||||
|
//! <td></td>
|
||||||
|
//! <td></td>
|
||||||
|
//! <td>+ - *</td>
|
||||||
|
//! <td>*</td>
|
||||||
|
//! </tr>
|
||||||
|
//! <tr>
|
||||||
|
//! <th>Dense</th>
|
||||||
|
//! <td></td>
|
||||||
|
//! <td></td>
|
||||||
|
//! <td></td>
|
||||||
|
//! <td>+ - *</td>
|
||||||
|
//! </tr>
|
||||||
|
//! </table>
|
||||||
|
//!
|
||||||
|
//! As can be seen from the table, only `CSR * Dense` and `CSC * Dense` are supported.
|
||||||
|
//! The other way around, i.e. `Dense * CSR` and `Dense * CSC` are not implemented.
|
||||||
|
//!
|
||||||
|
//! Additionally, [CsrMatrix](`crate::csr::CsrMatrix`) and [CooMatrix](`crate::coo::CooMatrix`)
|
||||||
|
//! support multiplication with scalars, in addition to division by a scalar.
|
||||||
|
//! Note that only `Matrix * Scalar` works in a generic context, although `Scalar * Matrix`
|
||||||
|
//! has been implemented for many of the built-in arithmetic types. This is due to a fundamental
|
||||||
|
//! restriction of the Rust type system. Therefore, in generic code you will need to always place
|
||||||
|
//! the matrix on the left-hand side of the multiplication.
|
||||||
|
//!
|
||||||
|
//! ## Unary operators
|
||||||
|
//!
|
||||||
|
//! The following table lists currently supported unary operators.
|
||||||
|
//!
|
||||||
|
//! | Format | AddAssign\<Matrix\> | MulAssign\<Matrix\> | MulAssign\<Scalar\> | Neg |
|
||||||
|
//! | -------- | ----------------- | ----------------- | ------------------- | ------ |
|
||||||
|
//! | COO | | | | |
|
||||||
|
//! | CSR | | | x | x |
|
||||||
|
//! | CSC | | | x | x |
|
||||||
|
//! |
|
||||||
|
//! # Example usage
|
||||||
|
//!
|
||||||
|
//! For example, consider the case where you want to compute the expression
|
||||||
|
//! `C <- 3.0 * C + 2.0 * A^T * B`, where `A`, `B`, `C` are matrices and `A^T` is the transpose
|
||||||
|
//! of `A`. The simplest way to write this is:
|
||||||
|
//!
|
||||||
|
//! ```rust
|
||||||
|
//! # use nalgebra_sparse::csr::CsrMatrix;
|
||||||
|
//! # let a = CsrMatrix::identity(10); let b = CsrMatrix::identity(10);
|
||||||
|
//! # let mut c = CsrMatrix::identity(10);
|
||||||
|
//! c = 3.0 * c + 2.0 * a.transpose() * b;
|
||||||
|
//! ```
|
||||||
|
//! This is simple and straightforward to read, and therefore the recommended way to implement
|
||||||
|
//! it. However, if you have determined that this is a performance bottleneck of your application,
|
||||||
|
//! it may be possible to speed things up. First, let's see what's going on here. The `std`
|
||||||
|
//! operations are evaluated eagerly. This means that the following steps take place:
|
||||||
|
//!
|
||||||
|
//! 1. Evaluate `let c_temp = 3.0 * c`. This requires scaling all values of the matrix.
|
||||||
|
//! 2. Evaluate `let a_t = a.transpose()` into a new temporary matrix.
|
||||||
|
//! 3. Evaluate `let a_t_b = a_t * b` into a new temporary matrix.
|
||||||
|
//! 4. Evaluate `let a_t_b_scaled = 2.0 * a_t_b`. This requires scaling all values of the matrix.
|
||||||
|
//! 5. Evaluate `c = c_temp + a_t_b_scaled`.
|
||||||
|
//!
|
||||||
|
//! An alternative way to implement this expression (here using CSR matrices) is:
|
||||||
|
//!
|
||||||
|
//! ```rust
|
||||||
|
//! # use nalgebra_sparse::csr::CsrMatrix;
|
||||||
|
//! # let a = CsrMatrix::identity(10); let b = CsrMatrix::identity(10);
|
||||||
|
//! # let mut c = CsrMatrix::identity(10);
|
||||||
|
//! use nalgebra_sparse::ops::{Op, serial::spmm_csr_prealloc};
|
||||||
|
//!
|
||||||
|
//! // Evaluate the expression `c <- 3.0 * c + 2.0 * a^T * b
|
||||||
|
//! spmm_csr_prealloc(3.0, &mut c, 2.0, Op::Transpose(&a), Op::NoOp(&b))
|
||||||
|
//! .expect("We assume that the pattern of C is able to accommodate the result.");
|
||||||
|
//! ```
|
||||||
|
//! Compared to the simpler example, this snippet is harder to read, but it calls a single
|
||||||
|
//! computational kernel that avoids many of the intermediate steps listed out before. Therefore
|
||||||
|
//! directly calling kernels may sometimes lead to better performance. However, this should
|
||||||
|
//! always be verified by performance profiling!
|
||||||
|
|
||||||
|
mod impl_std_ops;
|
||||||
|
pub mod serial;
|
||||||
|
|
||||||
|
/// Determines whether a matrix should be transposed in a given operation.
|
||||||
|
///
|
||||||
|
/// See the [module-level documentation](crate::ops) for the purpose of this enum.
|
||||||
|
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
|
||||||
|
pub enum Op<T> {
|
||||||
|
/// Indicates that the matrix should be used as-is.
|
||||||
|
NoOp(T),
|
||||||
|
/// Indicates that the matrix should be transposed.
|
||||||
|
Transpose(T),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T> Op<T> {
|
||||||
|
/// Returns a reference to the inner value that the operation applies to.
|
||||||
|
pub fn inner_ref(&self) -> &T {
|
||||||
|
self.as_ref().into_inner()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns an `Op` applied to a reference of the inner value.
|
||||||
|
pub fn as_ref(&self) -> Op<&T> {
|
||||||
|
match self {
|
||||||
|
Op::NoOp(obj) => Op::NoOp(&obj),
|
||||||
|
Op::Transpose(obj) => Op::Transpose(&obj),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Converts the underlying data type.
|
||||||
|
pub fn convert<U>(self) -> Op<U>
|
||||||
|
where
|
||||||
|
T: Into<U>,
|
||||||
|
{
|
||||||
|
self.map_same_op(T::into)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Transforms the inner value with the provided function, but preserves the operation.
|
||||||
|
pub fn map_same_op<U, F: FnOnce(T) -> U>(self, f: F) -> Op<U> {
|
||||||
|
match self {
|
||||||
|
Op::NoOp(obj) => Op::NoOp(f(obj)),
|
||||||
|
Op::Transpose(obj) => Op::Transpose(f(obj)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Consumes the `Op` and returns the inner value.
|
||||||
|
pub fn into_inner(self) -> T {
|
||||||
|
match self {
|
||||||
|
Op::NoOp(obj) | Op::Transpose(obj) => obj,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Applies the transpose operation.
|
||||||
|
///
|
||||||
|
/// This operation follows the usual semantics of transposition. In particular, double
|
||||||
|
/// transposition is equivalent to no transposition.
|
||||||
|
pub fn transposed(self) -> Self {
|
||||||
|
match self {
|
||||||
|
Op::NoOp(obj) => Op::Transpose(obj),
|
||||||
|
Op::Transpose(obj) => Op::NoOp(obj),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T> From<T> for Op<T> {
|
||||||
|
fn from(obj: T) -> Self {
|
||||||
|
Self::NoOp(obj)
|
||||||
|
}
|
||||||
|
}
|
186
nalgebra-sparse/src/ops/serial/cs.rs
Normal file
186
nalgebra-sparse/src/ops/serial/cs.rs
Normal file
@ -0,0 +1,186 @@
|
|||||||
|
use crate::cs::CsMatrix;
|
||||||
|
use crate::ops::serial::{OperationError, OperationErrorKind};
|
||||||
|
use crate::ops::Op;
|
||||||
|
use crate::SparseEntryMut;
|
||||||
|
use nalgebra::{ClosedAdd, ClosedMul, DMatrixSlice, DMatrixSliceMut, Scalar};
|
||||||
|
use num_traits::{One, Zero};
|
||||||
|
|
||||||
|
fn spmm_cs_unexpected_entry() -> OperationError {
|
||||||
|
OperationError::from_kind_and_message(
|
||||||
|
OperationErrorKind::InvalidPattern,
|
||||||
|
String::from("Found unexpected entry that is not present in `c`."),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Helper functionality for implementing CSR/CSC SPMM.
|
||||||
|
///
|
||||||
|
/// Since CSR/CSC matrices are basically transpositions of each other, which lets us use the same
|
||||||
|
/// algorithm for the SPMM implementation. The implementation here is written in a CSR-centric
|
||||||
|
/// manner. This means that when using it for CSC, the order of the matrices needs to be
|
||||||
|
/// reversed (since transpose(AB) = transpose(B) * transpose(A) and CSC(A) = transpose(CSR(A)).
|
||||||
|
///
|
||||||
|
/// We assume here that the matrices have already been verified to be dimensionally compatible.
|
||||||
|
pub fn spmm_cs_prealloc<T>(
|
||||||
|
beta: T,
|
||||||
|
c: &mut CsMatrix<T>,
|
||||||
|
alpha: T,
|
||||||
|
a: &CsMatrix<T>,
|
||||||
|
b: &CsMatrix<T>,
|
||||||
|
) -> Result<(), OperationError>
|
||||||
|
where
|
||||||
|
T: Scalar + ClosedAdd + ClosedMul + Zero + One,
|
||||||
|
{
|
||||||
|
for i in 0..c.pattern().major_dim() {
|
||||||
|
let a_lane_i = a.get_lane(i).unwrap();
|
||||||
|
let mut c_lane_i = c.get_lane_mut(i).unwrap();
|
||||||
|
for c_ij in c_lane_i.values_mut() {
|
||||||
|
*c_ij = beta.inlined_clone() * c_ij.inlined_clone();
|
||||||
|
}
|
||||||
|
|
||||||
|
for (&k, a_ik) in a_lane_i.minor_indices().iter().zip(a_lane_i.values()) {
|
||||||
|
let b_lane_k = b.get_lane(k).unwrap();
|
||||||
|
let (mut c_lane_i_cols, mut c_lane_i_values) = c_lane_i.indices_and_values_mut();
|
||||||
|
let alpha_aik = alpha.inlined_clone() * a_ik.inlined_clone();
|
||||||
|
for (j, b_kj) in b_lane_k.minor_indices().iter().zip(b_lane_k.values()) {
|
||||||
|
// Determine the location in C to append the value
|
||||||
|
let (c_local_idx, _) = c_lane_i_cols
|
||||||
|
.iter()
|
||||||
|
.enumerate()
|
||||||
|
.find(|(_, c_col)| *c_col == j)
|
||||||
|
.ok_or_else(spmm_cs_unexpected_entry)?;
|
||||||
|
|
||||||
|
c_lane_i_values[c_local_idx] += alpha_aik.inlined_clone() * b_kj.inlined_clone();
|
||||||
|
c_lane_i_cols = &c_lane_i_cols[c_local_idx..];
|
||||||
|
c_lane_i_values = &mut c_lane_i_values[c_local_idx..];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn spadd_cs_unexpected_entry() -> OperationError {
|
||||||
|
OperationError::from_kind_and_message(
|
||||||
|
OperationErrorKind::InvalidPattern,
|
||||||
|
String::from("Found entry in `op(a)` that is not present in `c`."),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Helper functionality for implementing CSR/CSC SPADD.
|
||||||
|
pub fn spadd_cs_prealloc<T>(
|
||||||
|
beta: T,
|
||||||
|
c: &mut CsMatrix<T>,
|
||||||
|
alpha: T,
|
||||||
|
a: Op<&CsMatrix<T>>,
|
||||||
|
) -> Result<(), OperationError>
|
||||||
|
where
|
||||||
|
T: Scalar + ClosedAdd + ClosedMul + Zero + One,
|
||||||
|
{
|
||||||
|
match a {
|
||||||
|
Op::NoOp(a) => {
|
||||||
|
for (mut c_lane_i, a_lane_i) in c.lane_iter_mut().zip(a.lane_iter()) {
|
||||||
|
if beta != T::one() {
|
||||||
|
for c_ij in c_lane_i.values_mut() {
|
||||||
|
*c_ij *= beta.inlined_clone();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let (mut c_minors, mut c_vals) = c_lane_i.indices_and_values_mut();
|
||||||
|
let (a_minors, a_vals) = (a_lane_i.minor_indices(), a_lane_i.values());
|
||||||
|
|
||||||
|
for (a_col, a_val) in a_minors.iter().zip(a_vals) {
|
||||||
|
// TODO: Use exponential search instead of linear search.
|
||||||
|
// If C has substantially more entries in the row than A, then a line search
|
||||||
|
// will needlessly visit many entries in C.
|
||||||
|
let (c_idx, _) = c_minors
|
||||||
|
.iter()
|
||||||
|
.enumerate()
|
||||||
|
.find(|(_, c_col)| *c_col == a_col)
|
||||||
|
.ok_or_else(spadd_cs_unexpected_entry)?;
|
||||||
|
c_vals[c_idx] += alpha.inlined_clone() * a_val.inlined_clone();
|
||||||
|
c_minors = &c_minors[c_idx..];
|
||||||
|
c_vals = &mut c_vals[c_idx..];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Op::Transpose(a) => {
|
||||||
|
if beta != T::one() {
|
||||||
|
for c_ij in c.values_mut() {
|
||||||
|
*c_ij *= beta.inlined_clone();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (i, a_lane_i) in a.lane_iter().enumerate() {
|
||||||
|
for (&j, a_val) in a_lane_i.minor_indices().iter().zip(a_lane_i.values()) {
|
||||||
|
let a_val = a_val.inlined_clone();
|
||||||
|
let alpha = alpha.inlined_clone();
|
||||||
|
match c.get_entry_mut(j, i).unwrap() {
|
||||||
|
SparseEntryMut::NonZero(c_ji) => *c_ji += alpha * a_val,
|
||||||
|
SparseEntryMut::Zero => return Err(spadd_cs_unexpected_entry()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Helper functionality for implementing CSR/CSC SPMM.
|
||||||
|
///
|
||||||
|
/// The implementation essentially assumes that `a` is a CSR matrix. To use it with CSC matrices,
|
||||||
|
/// the transposed operation must be specified for the CSC matrix.
|
||||||
|
pub fn spmm_cs_dense<T>(
|
||||||
|
beta: T,
|
||||||
|
mut c: DMatrixSliceMut<T>,
|
||||||
|
alpha: T,
|
||||||
|
a: Op<&CsMatrix<T>>,
|
||||||
|
b: Op<DMatrixSlice<T>>,
|
||||||
|
) where
|
||||||
|
T: Scalar + ClosedAdd + ClosedMul + Zero + One,
|
||||||
|
{
|
||||||
|
match a {
|
||||||
|
Op::NoOp(a) => {
|
||||||
|
for j in 0..c.ncols() {
|
||||||
|
let mut c_col_j = c.column_mut(j);
|
||||||
|
for (c_ij, a_row_i) in c_col_j.iter_mut().zip(a.lane_iter()) {
|
||||||
|
let mut dot_ij = T::zero();
|
||||||
|
for (&k, a_ik) in a_row_i.minor_indices().iter().zip(a_row_i.values()) {
|
||||||
|
let b_contrib = match b {
|
||||||
|
Op::NoOp(ref b) => b.index((k, j)),
|
||||||
|
Op::Transpose(ref b) => b.index((j, k)),
|
||||||
|
};
|
||||||
|
dot_ij += a_ik.inlined_clone() * b_contrib.inlined_clone();
|
||||||
|
}
|
||||||
|
*c_ij = beta.inlined_clone() * c_ij.inlined_clone()
|
||||||
|
+ alpha.inlined_clone() * dot_ij;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Op::Transpose(a) => {
|
||||||
|
// In this case, we have to pre-multiply C by beta
|
||||||
|
c *= beta;
|
||||||
|
|
||||||
|
for k in 0..a.pattern().major_dim() {
|
||||||
|
let a_row_k = a.get_lane(k).unwrap();
|
||||||
|
for (&i, a_ki) in a_row_k.minor_indices().iter().zip(a_row_k.values()) {
|
||||||
|
let gamma_ki = alpha.inlined_clone() * a_ki.inlined_clone();
|
||||||
|
let mut c_row_i = c.row_mut(i);
|
||||||
|
match b {
|
||||||
|
Op::NoOp(ref b) => {
|
||||||
|
let b_row_k = b.row(k);
|
||||||
|
for (c_ij, b_kj) in c_row_i.iter_mut().zip(b_row_k.iter()) {
|
||||||
|
*c_ij += gamma_ki.inlined_clone() * b_kj.inlined_clone();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Op::Transpose(ref b) => {
|
||||||
|
let b_col_k = b.column(k);
|
||||||
|
for (c_ij, b_jk) in c_row_i.iter_mut().zip(b_col_k.iter()) {
|
||||||
|
*c_ij += gamma_ki.inlined_clone() * b_jk.inlined_clone();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
255
nalgebra-sparse/src/ops/serial/csc.rs
Normal file
255
nalgebra-sparse/src/ops/serial/csc.rs
Normal file
@ -0,0 +1,255 @@
|
|||||||
|
use crate::csc::CscMatrix;
|
||||||
|
use crate::ops::serial::cs::{spadd_cs_prealloc, spmm_cs_dense, spmm_cs_prealloc};
|
||||||
|
use crate::ops::serial::{OperationError, OperationErrorKind};
|
||||||
|
use crate::ops::Op;
|
||||||
|
use nalgebra::{ClosedAdd, ClosedMul, DMatrixSlice, DMatrixSliceMut, RealField, Scalar};
|
||||||
|
use num_traits::{One, Zero};
|
||||||
|
|
||||||
|
use std::borrow::Cow;
|
||||||
|
|
||||||
|
/// Sparse-dense matrix-matrix multiplication `C <- beta * C + alpha * op(A) * op(B)`.
|
||||||
|
///
|
||||||
|
/// # Panics
|
||||||
|
///
|
||||||
|
/// Panics if the dimensions of the matrices involved are not compatible with the expression.
|
||||||
|
pub fn spmm_csc_dense<'a, T>(
|
||||||
|
beta: T,
|
||||||
|
c: impl Into<DMatrixSliceMut<'a, T>>,
|
||||||
|
alpha: T,
|
||||||
|
a: Op<&CscMatrix<T>>,
|
||||||
|
b: Op<impl Into<DMatrixSlice<'a, T>>>,
|
||||||
|
) where
|
||||||
|
T: Scalar + ClosedAdd + ClosedMul + Zero + One,
|
||||||
|
{
|
||||||
|
let b = b.convert();
|
||||||
|
spmm_csc_dense_(beta, c.into(), alpha, a, b)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn spmm_csc_dense_<T>(
|
||||||
|
beta: T,
|
||||||
|
c: DMatrixSliceMut<T>,
|
||||||
|
alpha: T,
|
||||||
|
a: Op<&CscMatrix<T>>,
|
||||||
|
b: Op<DMatrixSlice<T>>,
|
||||||
|
) where
|
||||||
|
T: Scalar + ClosedAdd + ClosedMul + Zero + One,
|
||||||
|
{
|
||||||
|
assert_compatible_spmm_dims!(c, a, b);
|
||||||
|
// Need to interpret matrix as transposed since the spmm_cs_dense function assumes CSR layout
|
||||||
|
let a = a.transposed().map_same_op(|a| &a.cs);
|
||||||
|
spmm_cs_dense(beta, c, alpha, a, b)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Sparse matrix addition `C <- beta * C + alpha * op(A)`.
|
||||||
|
///
|
||||||
|
/// If the pattern of `c` does not accommodate all the non-zero entries in `a`, an error is
|
||||||
|
/// returned.
|
||||||
|
///
|
||||||
|
/// # Panics
|
||||||
|
///
|
||||||
|
/// Panics if the dimensions of the matrices involved are not compatible with the expression.
|
||||||
|
pub fn spadd_csc_prealloc<T>(
|
||||||
|
beta: T,
|
||||||
|
c: &mut CscMatrix<T>,
|
||||||
|
alpha: T,
|
||||||
|
a: Op<&CscMatrix<T>>,
|
||||||
|
) -> Result<(), OperationError>
|
||||||
|
where
|
||||||
|
T: Scalar + ClosedAdd + ClosedMul + Zero + One,
|
||||||
|
{
|
||||||
|
assert_compatible_spadd_dims!(c, a);
|
||||||
|
spadd_cs_prealloc(beta, &mut c.cs, alpha, a.map_same_op(|a| &a.cs))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Sparse-sparse matrix multiplication, `C <- beta * C + alpha * op(A) * op(B)`.
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// If the sparsity pattern of `C` is not able to store the result of the operation,
|
||||||
|
/// an error is returned.
|
||||||
|
///
|
||||||
|
/// # Panics
|
||||||
|
///
|
||||||
|
/// Panics if the dimensions of the matrices involved are not compatible with the expression.
|
||||||
|
pub fn spmm_csc_prealloc<T>(
|
||||||
|
beta: T,
|
||||||
|
c: &mut CscMatrix<T>,
|
||||||
|
alpha: T,
|
||||||
|
a: Op<&CscMatrix<T>>,
|
||||||
|
b: Op<&CscMatrix<T>>,
|
||||||
|
) -> Result<(), OperationError>
|
||||||
|
where
|
||||||
|
T: Scalar + ClosedAdd + ClosedMul + Zero + One,
|
||||||
|
{
|
||||||
|
assert_compatible_spmm_dims!(c, a, b);
|
||||||
|
|
||||||
|
use Op::{NoOp, Transpose};
|
||||||
|
|
||||||
|
match (&a, &b) {
|
||||||
|
(NoOp(ref a), NoOp(ref b)) => {
|
||||||
|
// Note: We have to reverse the order for CSC matrices
|
||||||
|
spmm_cs_prealloc(beta, &mut c.cs, alpha, &b.cs, &a.cs)
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
// Currently we handle transposition by explicitly precomputing transposed matrices
|
||||||
|
// and calling the operation again without transposition
|
||||||
|
let a_ref: &CscMatrix<T> = a.inner_ref();
|
||||||
|
let b_ref: &CscMatrix<T> = b.inner_ref();
|
||||||
|
let (a, b) = {
|
||||||
|
use Cow::*;
|
||||||
|
match (&a, &b) {
|
||||||
|
(NoOp(_), NoOp(_)) => unreachable!(),
|
||||||
|
(Transpose(ref a), NoOp(_)) => (Owned(a.transpose()), Borrowed(b_ref)),
|
||||||
|
(NoOp(_), Transpose(ref b)) => (Borrowed(a_ref), Owned(b.transpose())),
|
||||||
|
(Transpose(ref a), Transpose(ref b)) => {
|
||||||
|
(Owned(a.transpose()), Owned(b.transpose()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
spmm_csc_prealloc(beta, c, alpha, NoOp(a.as_ref()), NoOp(b.as_ref()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Solve the lower triangular system `op(L) X = B`.
|
||||||
|
///
|
||||||
|
/// Only the lower triangular part of L is read, and the result is stored in B.
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// An error is returned if the system can not be solved due to the matrix being singular.
|
||||||
|
///
|
||||||
|
/// # Panics
|
||||||
|
///
|
||||||
|
/// Panics if `L` is not square, or if `L` and `B` are not dimensionally compatible.
|
||||||
|
pub fn spsolve_csc_lower_triangular<'a, T: RealField>(
|
||||||
|
l: Op<&CscMatrix<T>>,
|
||||||
|
b: impl Into<DMatrixSliceMut<'a, T>>,
|
||||||
|
) -> Result<(), OperationError> {
|
||||||
|
let b = b.into();
|
||||||
|
let l_matrix = l.into_inner();
|
||||||
|
assert_eq!(
|
||||||
|
l_matrix.nrows(),
|
||||||
|
l_matrix.ncols(),
|
||||||
|
"Matrix must be square for triangular solve."
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
l_matrix.nrows(),
|
||||||
|
b.nrows(),
|
||||||
|
"Dimension mismatch in sparse lower triangular solver."
|
||||||
|
);
|
||||||
|
match l {
|
||||||
|
Op::NoOp(a) => spsolve_csc_lower_triangular_no_transpose(a, b),
|
||||||
|
Op::Transpose(a) => spsolve_csc_lower_triangular_transpose(a, b),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn spsolve_csc_lower_triangular_no_transpose<T: RealField>(
|
||||||
|
l: &CscMatrix<T>,
|
||||||
|
b: DMatrixSliceMut<T>,
|
||||||
|
) -> Result<(), OperationError> {
|
||||||
|
let mut x = b;
|
||||||
|
|
||||||
|
// Solve column-by-column
|
||||||
|
for j in 0..x.ncols() {
|
||||||
|
let mut x_col_j = x.column_mut(j);
|
||||||
|
|
||||||
|
for k in 0..l.ncols() {
|
||||||
|
let l_col_k = l.col(k);
|
||||||
|
|
||||||
|
// Skip entries above the diagonal
|
||||||
|
// TODO: Can use exponential search here to quickly skip entries
|
||||||
|
// (we'd like to avoid using binary search as it's very cache unfriendly
|
||||||
|
// and the matrix might actually *be* lower triangular, which would induce
|
||||||
|
// a severe penalty)
|
||||||
|
let diag_csc_index = l_col_k.row_indices().iter().position(|&i| i == k);
|
||||||
|
if let Some(diag_csc_index) = diag_csc_index {
|
||||||
|
let l_kk = l_col_k.values()[diag_csc_index];
|
||||||
|
|
||||||
|
if l_kk != T::zero() {
|
||||||
|
// Update entry associated with diagonal
|
||||||
|
x_col_j[k] /= l_kk;
|
||||||
|
// Copy value after updating (so we don't run into the borrow checker)
|
||||||
|
let x_kj = x_col_j[k];
|
||||||
|
|
||||||
|
let row_indices = &l_col_k.row_indices()[(diag_csc_index + 1)..];
|
||||||
|
let l_values = &l_col_k.values()[(diag_csc_index + 1)..];
|
||||||
|
|
||||||
|
// Note: The remaining entries are below the diagonal
|
||||||
|
for (&i, l_ik) in row_indices.iter().zip(l_values) {
|
||||||
|
let x_ij = &mut x_col_j[i];
|
||||||
|
*x_ij -= l_ik.inlined_clone() * x_kj;
|
||||||
|
}
|
||||||
|
|
||||||
|
x_col_j[k] = x_kj;
|
||||||
|
} else {
|
||||||
|
return spsolve_encountered_zero_diagonal();
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return spsolve_encountered_zero_diagonal();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn spsolve_encountered_zero_diagonal() -> Result<(), OperationError> {
|
||||||
|
let message = "Matrix contains at least one diagonal entry that is zero.";
|
||||||
|
Err(OperationError::from_kind_and_message(
|
||||||
|
OperationErrorKind::Singular,
|
||||||
|
String::from(message),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn spsolve_csc_lower_triangular_transpose<T: RealField>(
|
||||||
|
l: &CscMatrix<T>,
|
||||||
|
b: DMatrixSliceMut<T>,
|
||||||
|
) -> Result<(), OperationError> {
|
||||||
|
let mut x = b;
|
||||||
|
|
||||||
|
// Solve column-by-column
|
||||||
|
for j in 0..x.ncols() {
|
||||||
|
let mut x_col_j = x.column_mut(j);
|
||||||
|
|
||||||
|
// Due to the transposition, we're essentially solving an upper triangular system,
|
||||||
|
// and the columns in our matrix become rows
|
||||||
|
|
||||||
|
for i in (0..l.ncols()).rev() {
|
||||||
|
let l_col_i = l.col(i);
|
||||||
|
|
||||||
|
// Skip entries above the diagonal
|
||||||
|
// TODO: Can use exponential search here to quickly skip entries
|
||||||
|
let diag_csc_index = l_col_i.row_indices().iter().position(|&k| i == k);
|
||||||
|
if let Some(diag_csc_index) = diag_csc_index {
|
||||||
|
let l_ii = l_col_i.values()[diag_csc_index];
|
||||||
|
|
||||||
|
if l_ii != T::zero() {
|
||||||
|
// // Update entry associated with diagonal
|
||||||
|
// x_col_j[k] /= a_kk;
|
||||||
|
|
||||||
|
// Copy value after updating (so we don't run into the borrow checker)
|
||||||
|
let mut x_ii = x_col_j[i];
|
||||||
|
|
||||||
|
let row_indices = &l_col_i.row_indices()[(diag_csc_index + 1)..];
|
||||||
|
let a_values = &l_col_i.values()[(diag_csc_index + 1)..];
|
||||||
|
|
||||||
|
// Note: The remaining entries are below the diagonal
|
||||||
|
for (&k, &l_ki) in row_indices.iter().zip(a_values) {
|
||||||
|
let x_kj = x_col_j[k];
|
||||||
|
x_ii -= l_ki * x_kj;
|
||||||
|
}
|
||||||
|
|
||||||
|
x_col_j[i] = x_ii / l_ii;
|
||||||
|
} else {
|
||||||
|
return spsolve_encountered_zero_diagonal();
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return spsolve_encountered_zero_diagonal();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
106
nalgebra-sparse/src/ops/serial/csr.rs
Normal file
106
nalgebra-sparse/src/ops/serial/csr.rs
Normal file
@ -0,0 +1,106 @@
|
|||||||
|
use crate::csr::CsrMatrix;
|
||||||
|
use crate::ops::serial::cs::{spadd_cs_prealloc, spmm_cs_dense, spmm_cs_prealloc};
|
||||||
|
use crate::ops::serial::OperationError;
|
||||||
|
use crate::ops::Op;
|
||||||
|
use nalgebra::{ClosedAdd, ClosedMul, DMatrixSlice, DMatrixSliceMut, Scalar};
|
||||||
|
use num_traits::{One, Zero};
|
||||||
|
use std::borrow::Cow;
|
||||||
|
|
||||||
|
/// Sparse-dense matrix-matrix multiplication `C <- beta * C + alpha * op(A) * op(B)`.
|
||||||
|
pub fn spmm_csr_dense<'a, T>(
|
||||||
|
beta: T,
|
||||||
|
c: impl Into<DMatrixSliceMut<'a, T>>,
|
||||||
|
alpha: T,
|
||||||
|
a: Op<&CsrMatrix<T>>,
|
||||||
|
b: Op<impl Into<DMatrixSlice<'a, T>>>,
|
||||||
|
) where
|
||||||
|
T: Scalar + ClosedAdd + ClosedMul + Zero + One,
|
||||||
|
{
|
||||||
|
let b = b.convert();
|
||||||
|
spmm_csr_dense_(beta, c.into(), alpha, a, b)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn spmm_csr_dense_<T>(
|
||||||
|
beta: T,
|
||||||
|
c: DMatrixSliceMut<T>,
|
||||||
|
alpha: T,
|
||||||
|
a: Op<&CsrMatrix<T>>,
|
||||||
|
b: Op<DMatrixSlice<T>>,
|
||||||
|
) where
|
||||||
|
T: Scalar + ClosedAdd + ClosedMul + Zero + One,
|
||||||
|
{
|
||||||
|
assert_compatible_spmm_dims!(c, a, b);
|
||||||
|
spmm_cs_dense(beta, c, alpha, a.map_same_op(|a| &a.cs), b)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Sparse matrix addition `C <- beta * C + alpha * op(A)`.
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// If the pattern of `c` does not accommodate all the non-zero entries in `a`, an error is
|
||||||
|
/// returned.
|
||||||
|
///
|
||||||
|
/// # Panics
|
||||||
|
///
|
||||||
|
/// Panics if the dimensions of the matrices involved are not compatible with the expression.
|
||||||
|
pub fn spadd_csr_prealloc<T>(
|
||||||
|
beta: T,
|
||||||
|
c: &mut CsrMatrix<T>,
|
||||||
|
alpha: T,
|
||||||
|
a: Op<&CsrMatrix<T>>,
|
||||||
|
) -> Result<(), OperationError>
|
||||||
|
where
|
||||||
|
T: Scalar + ClosedAdd + ClosedMul + Zero + One,
|
||||||
|
{
|
||||||
|
assert_compatible_spadd_dims!(c, a);
|
||||||
|
spadd_cs_prealloc(beta, &mut c.cs, alpha, a.map_same_op(|a| &a.cs))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Sparse-sparse matrix multiplication, `C <- beta * C + alpha * op(A) * op(B)`.
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// If the pattern of `C` is not able to hold the result of the operation, an error is returned.
|
||||||
|
///
|
||||||
|
/// # Panics
|
||||||
|
///
|
||||||
|
/// Panics if the dimensions of the matrices involved are not compatible with the expression.
|
||||||
|
pub fn spmm_csr_prealloc<T>(
|
||||||
|
beta: T,
|
||||||
|
c: &mut CsrMatrix<T>,
|
||||||
|
alpha: T,
|
||||||
|
a: Op<&CsrMatrix<T>>,
|
||||||
|
b: Op<&CsrMatrix<T>>,
|
||||||
|
) -> Result<(), OperationError>
|
||||||
|
where
|
||||||
|
T: Scalar + ClosedAdd + ClosedMul + Zero + One,
|
||||||
|
{
|
||||||
|
assert_compatible_spmm_dims!(c, a, b);
|
||||||
|
|
||||||
|
use Op::{NoOp, Transpose};
|
||||||
|
|
||||||
|
match (&a, &b) {
|
||||||
|
(NoOp(ref a), NoOp(ref b)) => spmm_cs_prealloc(beta, &mut c.cs, alpha, &a.cs, &b.cs),
|
||||||
|
_ => {
|
||||||
|
// Currently we handle transposition by explicitly precomputing transposed matrices
|
||||||
|
// and calling the operation again without transposition
|
||||||
|
// TODO: At least use workspaces to allow control of allocations. Maybe
|
||||||
|
// consider implementing certain patterns (like A^T * B) explicitly
|
||||||
|
let a_ref: &CsrMatrix<T> = a.inner_ref();
|
||||||
|
let b_ref: &CsrMatrix<T> = b.inner_ref();
|
||||||
|
let (a, b) = {
|
||||||
|
use Cow::*;
|
||||||
|
match (&a, &b) {
|
||||||
|
(NoOp(_), NoOp(_)) => unreachable!(),
|
||||||
|
(Transpose(ref a), NoOp(_)) => (Owned(a.transpose()), Borrowed(b_ref)),
|
||||||
|
(NoOp(_), Transpose(ref b)) => (Borrowed(a_ref), Owned(b.transpose())),
|
||||||
|
(Transpose(ref a), Transpose(ref b)) => {
|
||||||
|
(Owned(a.transpose()), Owned(b.transpose()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
spmm_csr_prealloc(beta, c, alpha, NoOp(a.as_ref()), NoOp(b.as_ref()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
124
nalgebra-sparse/src/ops/serial/mod.rs
Normal file
124
nalgebra-sparse/src/ops/serial/mod.rs
Normal file
@ -0,0 +1,124 @@
|
|||||||
|
//! Serial sparse matrix arithmetic routines.
|
||||||
|
//!
|
||||||
|
//! All routines are single-threaded.
|
||||||
|
//!
|
||||||
|
//! Some operations have the `prealloc` suffix. This means that they expect that the sparsity
|
||||||
|
//! pattern of the output matrix has already been pre-allocated: that is, the pattern of the result
|
||||||
|
//! of the operation fits entirely in the output pattern. In the future, there will also be
|
||||||
|
//! some operations which will be able to dynamically adapt the output pattern to fit the
|
||||||
|
//! result, but these have yet to be implemented.
|
||||||
|
|
||||||
|
#[macro_use]
|
||||||
|
macro_rules! assert_compatible_spmm_dims {
|
||||||
|
($c:expr, $a:expr, $b:expr) => {{
|
||||||
|
use crate::ops::Op::{NoOp, Transpose};
|
||||||
|
match (&$a, &$b) {
|
||||||
|
(NoOp(ref a), NoOp(ref b)) => {
|
||||||
|
assert_eq!($c.nrows(), a.nrows(), "C.nrows() != A.nrows()");
|
||||||
|
assert_eq!($c.ncols(), b.ncols(), "C.ncols() != B.ncols()");
|
||||||
|
assert_eq!(a.ncols(), b.nrows(), "A.ncols() != B.nrows()");
|
||||||
|
}
|
||||||
|
(Transpose(ref a), NoOp(ref b)) => {
|
||||||
|
assert_eq!($c.nrows(), a.ncols(), "C.nrows() != A.ncols()");
|
||||||
|
assert_eq!($c.ncols(), b.ncols(), "C.ncols() != B.ncols()");
|
||||||
|
assert_eq!(a.nrows(), b.nrows(), "A.nrows() != B.nrows()");
|
||||||
|
}
|
||||||
|
(NoOp(ref a), Transpose(ref b)) => {
|
||||||
|
assert_eq!($c.nrows(), a.nrows(), "C.nrows() != A.nrows()");
|
||||||
|
assert_eq!($c.ncols(), b.nrows(), "C.ncols() != B.nrows()");
|
||||||
|
assert_eq!(a.ncols(), b.ncols(), "A.ncols() != B.ncols()");
|
||||||
|
}
|
||||||
|
(Transpose(ref a), Transpose(ref b)) => {
|
||||||
|
assert_eq!($c.nrows(), a.ncols(), "C.nrows() != A.ncols()");
|
||||||
|
assert_eq!($c.ncols(), b.nrows(), "C.ncols() != B.nrows()");
|
||||||
|
assert_eq!(a.nrows(), b.ncols(), "A.nrows() != B.ncols()");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}};
|
||||||
|
}
|
||||||
|
|
||||||
|
#[macro_use]
|
||||||
|
macro_rules! assert_compatible_spadd_dims {
|
||||||
|
($c:expr, $a:expr) => {
|
||||||
|
use crate::ops::Op;
|
||||||
|
match $a {
|
||||||
|
Op::NoOp(a) => {
|
||||||
|
assert_eq!($c.nrows(), a.nrows(), "C.nrows() != A.nrows()");
|
||||||
|
assert_eq!($c.ncols(), a.ncols(), "C.ncols() != A.ncols()");
|
||||||
|
}
|
||||||
|
Op::Transpose(a) => {
|
||||||
|
assert_eq!($c.nrows(), a.ncols(), "C.nrows() != A.ncols()");
|
||||||
|
assert_eq!($c.ncols(), a.nrows(), "C.ncols() != A.nrows()");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
mod cs;
|
||||||
|
mod csc;
|
||||||
|
mod csr;
|
||||||
|
mod pattern;
|
||||||
|
|
||||||
|
pub use csc::*;
|
||||||
|
pub use csr::*;
|
||||||
|
pub use pattern::*;
|
||||||
|
use std::fmt;
|
||||||
|
use std::fmt::Formatter;
|
||||||
|
|
||||||
|
/// A description of the error that occurred during an arithmetic operation.
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
pub struct OperationError {
|
||||||
|
error_kind: OperationErrorKind,
|
||||||
|
message: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The different kinds of operation errors that may occur.
|
||||||
|
#[non_exhaustive]
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
pub enum OperationErrorKind {
|
||||||
|
/// Indicates that one or more sparsity patterns involved in the operation violate the
|
||||||
|
/// expectations of the routine.
|
||||||
|
///
|
||||||
|
/// For example, this could indicate that the sparsity pattern of the output is not able to
|
||||||
|
/// contain the result of the operation.
|
||||||
|
InvalidPattern,
|
||||||
|
|
||||||
|
/// Indicates that a matrix is singular when it is expected to be invertible.
|
||||||
|
Singular,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl OperationError {
|
||||||
|
fn from_kind_and_message(error_type: OperationErrorKind, message: String) -> Self {
|
||||||
|
Self {
|
||||||
|
error_kind: error_type,
|
||||||
|
message,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The operation error kind.
|
||||||
|
pub fn kind(&self) -> &OperationErrorKind {
|
||||||
|
&self.error_kind
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The underlying error message.
|
||||||
|
pub fn message(&self) -> &str {
|
||||||
|
self.message.as_str()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl fmt::Display for OperationError {
|
||||||
|
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
|
||||||
|
write!(f, "Sparse matrix operation error: ")?;
|
||||||
|
match self.kind() {
|
||||||
|
OperationErrorKind::InvalidPattern => {
|
||||||
|
write!(f, "InvalidPattern")?;
|
||||||
|
}
|
||||||
|
OperationErrorKind::Singular => {
|
||||||
|
write!(f, "Singular")?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
write!(f, " Message: {}", self.message)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::error::Error for OperationError {}
|
152
nalgebra-sparse/src/ops/serial/pattern.rs
Normal file
152
nalgebra-sparse/src/ops/serial/pattern.rs
Normal file
@ -0,0 +1,152 @@
|
|||||||
|
use crate::pattern::SparsityPattern;
|
||||||
|
|
||||||
|
use std::iter;
|
||||||
|
|
||||||
|
/// Sparse matrix addition pattern construction, `C <- A + B`.
|
||||||
|
///
|
||||||
|
/// Builds the pattern for `C`, which is able to hold the result of the sum `A + B`.
|
||||||
|
/// The patterns are assumed to have the same major and minor dimensions. In other words,
|
||||||
|
/// both patterns `A` and `B` must both stem from the same kind of compressed matrix:
|
||||||
|
/// CSR or CSC.
|
||||||
|
///
|
||||||
|
/// # Panics
|
||||||
|
///
|
||||||
|
/// Panics if the patterns do not have the same major and minor dimensions.
|
||||||
|
pub fn spadd_pattern(a: &SparsityPattern, b: &SparsityPattern) -> SparsityPattern {
|
||||||
|
assert_eq!(
|
||||||
|
a.major_dim(),
|
||||||
|
b.major_dim(),
|
||||||
|
"Patterns must have identical major dimensions."
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
a.minor_dim(),
|
||||||
|
b.minor_dim(),
|
||||||
|
"Patterns must have identical minor dimensions."
|
||||||
|
);
|
||||||
|
|
||||||
|
let mut offsets = Vec::new();
|
||||||
|
let mut indices = Vec::new();
|
||||||
|
offsets.reserve(a.major_dim() + 1);
|
||||||
|
indices.clear();
|
||||||
|
|
||||||
|
offsets.push(0);
|
||||||
|
|
||||||
|
for lane_idx in 0..a.major_dim() {
|
||||||
|
let lane_a = a.lane(lane_idx);
|
||||||
|
let lane_b = b.lane(lane_idx);
|
||||||
|
indices.extend(iterate_union(lane_a, lane_b));
|
||||||
|
offsets.push(indices.len());
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Consider circumventing format checks? (requires unsafe, should benchmark first)
|
||||||
|
SparsityPattern::try_from_offsets_and_indices(a.major_dim(), a.minor_dim(), offsets, indices)
|
||||||
|
.expect("Internal error: Pattern must be valid by definition")
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Sparse matrix multiplication pattern construction, `C <- A * B`.
|
||||||
|
///
|
||||||
|
/// Assumes that the sparsity patterns both represent CSC matrices, and the result is also
|
||||||
|
/// represented as the sparsity pattern of a CSC matrix.
|
||||||
|
///
|
||||||
|
/// # Panics
|
||||||
|
///
|
||||||
|
/// Panics if the patterns, when interpreted as CSC patterns, are not compatible for
|
||||||
|
/// matrix multiplication.
|
||||||
|
pub fn spmm_csc_pattern(a: &SparsityPattern, b: &SparsityPattern) -> SparsityPattern {
|
||||||
|
// Let C = A * B in CSC format. We note that
|
||||||
|
// C^T = B^T * A^T.
|
||||||
|
// Since the interpretation of a CSC matrix in CSR format represents the transpose of the
|
||||||
|
// matrix in CSR, we can compute C^T in *CSR format* by switching the order of a and b,
|
||||||
|
// which lets us obtain C^T in CSR format. Re-interpreting this as CSC gives us C in CSC format
|
||||||
|
spmm_csr_pattern(b, a)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Sparse matrix multiplication pattern construction, `C <- A * B`.
|
||||||
|
///
|
||||||
|
/// Assumes that the sparsity patterns both represent CSR matrices, and the result is also
|
||||||
|
/// represented as the sparsity pattern of a CSR matrix.
|
||||||
|
///
|
||||||
|
/// # Panics
|
||||||
|
///
|
||||||
|
/// Panics if the patterns, when interpreted as CSR patterns, are not compatible for
|
||||||
|
/// matrix multiplication.
|
||||||
|
pub fn spmm_csr_pattern(a: &SparsityPattern, b: &SparsityPattern) -> SparsityPattern {
|
||||||
|
assert_eq!(
|
||||||
|
a.minor_dim(),
|
||||||
|
b.major_dim(),
|
||||||
|
"a and b must have compatible dimensions"
|
||||||
|
);
|
||||||
|
|
||||||
|
let mut offsets = Vec::new();
|
||||||
|
let mut indices = Vec::new();
|
||||||
|
offsets.push(0);
|
||||||
|
|
||||||
|
// Keep a vector of whether we have visited a particular minor index when working
|
||||||
|
// on a major lane
|
||||||
|
// TODO: Consider using a bitvec or similar here to reduce pressure on memory
|
||||||
|
// (would cut memory use to 1/8, which might help reduce cache misses)
|
||||||
|
let mut visited = vec![false; b.minor_dim()];
|
||||||
|
|
||||||
|
for i in 0..a.major_dim() {
|
||||||
|
let a_lane_i = a.lane(i);
|
||||||
|
let c_lane_i_offset = *offsets.last().unwrap();
|
||||||
|
for &k in a_lane_i {
|
||||||
|
let b_lane_k = b.lane(k);
|
||||||
|
|
||||||
|
for &j in b_lane_k {
|
||||||
|
let have_visited_j = &mut visited[j];
|
||||||
|
if !*have_visited_j {
|
||||||
|
indices.push(j);
|
||||||
|
*have_visited_j = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let c_lane_i = &mut indices[c_lane_i_offset..];
|
||||||
|
c_lane_i.sort_unstable();
|
||||||
|
|
||||||
|
// Reset visits so that visited[j] == false for all j for the next major lane
|
||||||
|
for j in c_lane_i {
|
||||||
|
visited[*j] = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
offsets.push(indices.len());
|
||||||
|
}
|
||||||
|
|
||||||
|
SparsityPattern::try_from_offsets_and_indices(a.major_dim(), b.minor_dim(), offsets, indices)
|
||||||
|
.expect("Internal error: Invalid pattern during matrix multiplication pattern construction")
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Iterate over the union of the two sets represented by sorted slices
|
||||||
|
/// (with unique elements)
|
||||||
|
fn iterate_union<'a>(
|
||||||
|
mut sorted_a: &'a [usize],
|
||||||
|
mut sorted_b: &'a [usize],
|
||||||
|
) -> impl Iterator<Item = usize> + 'a {
|
||||||
|
iter::from_fn(move || {
|
||||||
|
if let (Some(a_item), Some(b_item)) = (sorted_a.first(), sorted_b.first()) {
|
||||||
|
let item = if a_item < b_item {
|
||||||
|
sorted_a = &sorted_a[1..];
|
||||||
|
a_item
|
||||||
|
} else if b_item < a_item {
|
||||||
|
sorted_b = &sorted_b[1..];
|
||||||
|
b_item
|
||||||
|
} else {
|
||||||
|
// Both lists contain the same element, advance both slices to avoid
|
||||||
|
// duplicate entries in the result
|
||||||
|
sorted_a = &sorted_a[1..];
|
||||||
|
sorted_b = &sorted_b[1..];
|
||||||
|
a_item
|
||||||
|
};
|
||||||
|
Some(*item)
|
||||||
|
} else if let Some(a_item) = sorted_a.first() {
|
||||||
|
sorted_a = &sorted_a[1..];
|
||||||
|
Some(*a_item)
|
||||||
|
} else if let Some(b_item) = sorted_b.first() {
|
||||||
|
sorted_b = &sorted_b[1..];
|
||||||
|
Some(*b_item)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
393
nalgebra-sparse/src/pattern.rs
Normal file
393
nalgebra-sparse/src/pattern.rs
Normal file
@ -0,0 +1,393 @@
|
|||||||
|
//! Sparsity patterns for CSR and CSC matrices.
|
||||||
|
use crate::cs::transpose_cs;
|
||||||
|
use crate::SparseFormatError;
|
||||||
|
use std::error::Error;
|
||||||
|
use std::fmt;
|
||||||
|
|
||||||
|
/// A representation of the sparsity pattern of a CSR or CSC matrix.
|
||||||
|
///
|
||||||
|
/// CSR and CSC matrices store matrices in a very similar fashion. In fact, in a certain sense,
|
||||||
|
/// they are transposed. More precisely, when reinterpreting the three data arrays of a CSR
|
||||||
|
/// matrix as a CSC matrix, we obtain the CSC representation of its transpose.
|
||||||
|
///
|
||||||
|
/// [`SparsityPattern`] is an abstraction built on this observation. Whereas CSR matrices
|
||||||
|
/// store a matrix row-by-row, and a CSC matrix stores a matrix column-by-column, a
|
||||||
|
/// `SparsityPattern` represents only the index data structure of a matrix *lane-by-lane*.
|
||||||
|
/// Here, a *lane* is a generalization of rows and columns. We further define *major lanes*
|
||||||
|
/// and *minor lanes*. The sparsity pattern of a CSR matrix is then obtained by interpreting
|
||||||
|
/// major/minor as row/column. Conversely, we obtain the sparsity pattern of a CSC matrix by
|
||||||
|
/// interpreting major/minor as column/row.
|
||||||
|
///
|
||||||
|
/// This allows us to use a common abstraction to talk about sparsity patterns of CSR and CSC
|
||||||
|
/// matrices. This is convenient, because at the abstract level, the invariants of the formats
|
||||||
|
/// are the same. Hence we may encode the invariants of the index data structure separately from
|
||||||
|
/// the scalar values of the matrix. This is especially useful in applications where the
|
||||||
|
/// sparsity pattern is built ahead of the matrix values, or the same sparsity pattern is re-used
|
||||||
|
/// between different matrices. Finally, we can use `SparsityPattern` to encode adjacency
|
||||||
|
/// information in graphs.
|
||||||
|
///
|
||||||
|
/// # Format
|
||||||
|
///
|
||||||
|
/// The format is exactly the same as for the index data structures of CSR and CSC matrices.
|
||||||
|
/// This means that the sparsity pattern of an `m x n` sparse matrix with `nnz` non-zeros,
|
||||||
|
/// where in this case `m x n` does *not* mean `rows x columns`, but rather `majors x minors`,
|
||||||
|
/// is represented by the following two arrays:
|
||||||
|
///
|
||||||
|
/// - `major_offsets`, an array of integers with length `m + 1`.
|
||||||
|
/// - `minor_indices`, an array of integers with length `nnz`.
|
||||||
|
///
|
||||||
|
/// The invariants and relationship between `major_offsets` and `minor_indices` remain the same
|
||||||
|
/// as for `row_offsets` and `col_indices` in the [CSR](`crate::csr::CsrMatrix`) format
|
||||||
|
/// specification.
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
// TODO: Make SparsityPattern parametrized by index type
|
||||||
|
// (need a solid abstraction for index types though)
|
||||||
|
pub struct SparsityPattern {
|
||||||
|
major_offsets: Vec<usize>,
|
||||||
|
minor_indices: Vec<usize>,
|
||||||
|
minor_dim: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl SparsityPattern {
|
||||||
|
/// Create a sparsity pattern of the given dimensions without explicitly stored entries.
|
||||||
|
pub fn zeros(major_dim: usize, minor_dim: usize) -> Self {
|
||||||
|
Self {
|
||||||
|
major_offsets: vec![0; major_dim + 1],
|
||||||
|
minor_indices: vec![],
|
||||||
|
minor_dim,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The offsets for the major dimension.
|
||||||
|
#[inline]
|
||||||
|
pub fn major_offsets(&self) -> &[usize] {
|
||||||
|
&self.major_offsets
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The indices for the minor dimension.
|
||||||
|
#[inline]
|
||||||
|
pub fn minor_indices(&self) -> &[usize] {
|
||||||
|
&self.minor_indices
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The number of major lanes in the pattern.
|
||||||
|
#[inline]
|
||||||
|
pub fn major_dim(&self) -> usize {
|
||||||
|
assert!(self.major_offsets.len() > 0);
|
||||||
|
self.major_offsets.len() - 1
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The number of minor lanes in the pattern.
|
||||||
|
#[inline]
|
||||||
|
pub fn minor_dim(&self) -> usize {
|
||||||
|
self.minor_dim
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The number of "non-zeros", i.e. explicitly stored entries in the pattern.
|
||||||
|
#[inline]
|
||||||
|
pub fn nnz(&self) -> usize {
|
||||||
|
self.minor_indices.len()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the lane at the given index.
|
||||||
|
///
|
||||||
|
/// Panics
|
||||||
|
/// ------
|
||||||
|
///
|
||||||
|
/// Panics if `major_index` is out of bounds.
|
||||||
|
#[inline]
|
||||||
|
pub fn lane(&self, major_index: usize) -> &[usize] {
|
||||||
|
self.get_lane(major_index).unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the lane at the given index, or `None` if out of bounds.
|
||||||
|
#[inline]
|
||||||
|
pub fn get_lane(&self, major_index: usize) -> Option<&[usize]> {
|
||||||
|
let offset_begin = *self.major_offsets().get(major_index)?;
|
||||||
|
let offset_end = *self.major_offsets().get(major_index + 1)?;
|
||||||
|
Some(&self.minor_indices()[offset_begin..offset_end])
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Try to construct a sparsity pattern from the given dimensions, major offsets
|
||||||
|
/// and minor indices.
|
||||||
|
///
|
||||||
|
/// Returns an error if the data does not conform to the requirements.
|
||||||
|
pub fn try_from_offsets_and_indices(
|
||||||
|
major_dim: usize,
|
||||||
|
minor_dim: usize,
|
||||||
|
major_offsets: Vec<usize>,
|
||||||
|
minor_indices: Vec<usize>,
|
||||||
|
) -> Result<Self, SparsityPatternFormatError> {
|
||||||
|
use SparsityPatternFormatError::*;
|
||||||
|
|
||||||
|
if major_offsets.len() != major_dim + 1 {
|
||||||
|
return Err(InvalidOffsetArrayLength);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that the first and last offsets conform to the specification
|
||||||
|
{
|
||||||
|
let first_offset_ok = *major_offsets.first().unwrap() == 0;
|
||||||
|
let last_offset_ok = *major_offsets.last().unwrap() == minor_indices.len();
|
||||||
|
if !first_offset_ok || !last_offset_ok {
|
||||||
|
return Err(InvalidOffsetFirstLast);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test that each lane has strictly monotonically increasing minor indices, i.e.
|
||||||
|
// minor indices within a lane are sorted, unique. In addition, each minor index
|
||||||
|
// must be in bounds with respect to the minor dimension.
|
||||||
|
{
|
||||||
|
for lane_idx in 0..major_dim {
|
||||||
|
let range_start = major_offsets[lane_idx];
|
||||||
|
let range_end = major_offsets[lane_idx + 1];
|
||||||
|
|
||||||
|
// Test that major offsets are monotonically increasing
|
||||||
|
if range_start > range_end {
|
||||||
|
return Err(NonmonotonicOffsets);
|
||||||
|
}
|
||||||
|
|
||||||
|
let minor_indices = &minor_indices[range_start..range_end];
|
||||||
|
|
||||||
|
// We test for in-bounds, uniqueness and monotonicity at the same time
|
||||||
|
// to ensure that we only visit each minor index once
|
||||||
|
let mut iter = minor_indices.iter();
|
||||||
|
let mut prev = None;
|
||||||
|
|
||||||
|
while let Some(next) = iter.next().copied() {
|
||||||
|
if next >= minor_dim {
|
||||||
|
return Err(MinorIndexOutOfBounds);
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(prev) = prev {
|
||||||
|
if prev > next {
|
||||||
|
return Err(NonmonotonicMinorIndices);
|
||||||
|
} else if prev == next {
|
||||||
|
return Err(DuplicateEntry);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
prev = Some(next);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
major_offsets,
|
||||||
|
minor_indices,
|
||||||
|
minor_dim,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// An iterator over the explicitly stored "non-zero" entries (i, j).
|
||||||
|
///
|
||||||
|
/// The iteration happens in a lane-major fashion, meaning that the lane index i
|
||||||
|
/// increases monotonically, and the minor index j increases monotonically within each
|
||||||
|
/// lane i.
|
||||||
|
///
|
||||||
|
/// Examples
|
||||||
|
/// --------
|
||||||
|
///
|
||||||
|
/// ```
|
||||||
|
/// # use nalgebra_sparse::pattern::SparsityPattern;
|
||||||
|
/// let offsets = vec![0, 2, 3, 4];
|
||||||
|
/// let minor_indices = vec![0, 2, 1, 0];
|
||||||
|
/// let pattern = SparsityPattern::try_from_offsets_and_indices(3, 4, offsets, minor_indices)
|
||||||
|
/// .unwrap();
|
||||||
|
///
|
||||||
|
/// let entries: Vec<_> = pattern.entries().collect();
|
||||||
|
/// assert_eq!(entries, vec![(0, 0), (0, 2), (1, 1), (2, 0)]);
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
|
pub fn entries(&self) -> SparsityPatternIter {
|
||||||
|
SparsityPatternIter::from_pattern(self)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the raw offset and index data for the sparsity pattern.
|
||||||
|
///
|
||||||
|
/// Examples
|
||||||
|
/// --------
|
||||||
|
///
|
||||||
|
/// ```
|
||||||
|
/// # use nalgebra_sparse::pattern::SparsityPattern;
|
||||||
|
/// let offsets = vec![0, 2, 3, 4];
|
||||||
|
/// let minor_indices = vec![0, 2, 1, 0];
|
||||||
|
/// let pattern = SparsityPattern::try_from_offsets_and_indices(
|
||||||
|
/// 3,
|
||||||
|
/// 4,
|
||||||
|
/// offsets.clone(),
|
||||||
|
/// minor_indices.clone())
|
||||||
|
/// .unwrap();
|
||||||
|
/// let (offsets2, minor_indices2) = pattern.disassemble();
|
||||||
|
/// assert_eq!(offsets2, offsets);
|
||||||
|
/// assert_eq!(minor_indices2, minor_indices);
|
||||||
|
/// ```
|
||||||
|
pub fn disassemble(self) -> (Vec<usize>, Vec<usize>) {
|
||||||
|
(self.major_offsets, self.minor_indices)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Computes the transpose of the sparsity pattern.
|
||||||
|
///
|
||||||
|
/// This is analogous to matrix transposition, i.e. an entry `(i, j)` becomes `(j, i)` in the
|
||||||
|
/// new pattern.
|
||||||
|
pub fn transpose(&self) -> Self {
|
||||||
|
// By using unit () values, we can use the same routines as for CSR/CSC matrices
|
||||||
|
let values = vec![(); self.nnz()];
|
||||||
|
let (new_offsets, new_indices, _) = transpose_cs(
|
||||||
|
self.major_dim(),
|
||||||
|
self.minor_dim(),
|
||||||
|
self.major_offsets(),
|
||||||
|
self.minor_indices(),
|
||||||
|
&values,
|
||||||
|
);
|
||||||
|
// TODO: Skip checks
|
||||||
|
Self::try_from_offsets_and_indices(
|
||||||
|
self.minor_dim(),
|
||||||
|
self.major_dim(),
|
||||||
|
new_offsets,
|
||||||
|
new_indices,
|
||||||
|
)
|
||||||
|
.expect("Internal error: Transpose should never fail.")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Error type for `SparsityPattern` format errors.
|
||||||
|
#[non_exhaustive]
|
||||||
|
#[derive(Debug, PartialEq, Eq)]
|
||||||
|
pub enum SparsityPatternFormatError {
|
||||||
|
/// Indicates an invalid number of offsets.
|
||||||
|
///
|
||||||
|
/// The number of offsets must be equal to (major_dim + 1).
|
||||||
|
InvalidOffsetArrayLength,
|
||||||
|
/// Indicates that the first or last entry in the offset array did not conform to
|
||||||
|
/// specifications.
|
||||||
|
///
|
||||||
|
/// The first entry must be 0, and the last entry must be exactly one greater than the
|
||||||
|
/// major dimension.
|
||||||
|
InvalidOffsetFirstLast,
|
||||||
|
/// Indicates that the major offsets are not monotonically increasing.
|
||||||
|
NonmonotonicOffsets,
|
||||||
|
/// One or more minor indices are out of bounds.
|
||||||
|
MinorIndexOutOfBounds,
|
||||||
|
/// One or more duplicate entries were detected.
|
||||||
|
///
|
||||||
|
/// Two entries are considered duplicates if they are part of the same major lane and have
|
||||||
|
/// the same minor index.
|
||||||
|
DuplicateEntry,
|
||||||
|
/// Indicates that minor indices are not monotonically increasing within each lane.
|
||||||
|
NonmonotonicMinorIndices,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<SparsityPatternFormatError> for SparseFormatError {
|
||||||
|
fn from(err: SparsityPatternFormatError) -> Self {
|
||||||
|
use crate::SparseFormatErrorKind;
|
||||||
|
use crate::SparseFormatErrorKind::*;
|
||||||
|
use SparsityPatternFormatError::DuplicateEntry as PatternDuplicateEntry;
|
||||||
|
use SparsityPatternFormatError::*;
|
||||||
|
match err {
|
||||||
|
InvalidOffsetArrayLength
|
||||||
|
| InvalidOffsetFirstLast
|
||||||
|
| NonmonotonicOffsets
|
||||||
|
| NonmonotonicMinorIndices => {
|
||||||
|
SparseFormatError::from_kind_and_error(InvalidStructure, Box::from(err))
|
||||||
|
}
|
||||||
|
MinorIndexOutOfBounds => {
|
||||||
|
SparseFormatError::from_kind_and_error(IndexOutOfBounds, Box::from(err))
|
||||||
|
}
|
||||||
|
PatternDuplicateEntry => SparseFormatError::from_kind_and_error(
|
||||||
|
#[allow(unused_qualifications)]
|
||||||
|
SparseFormatErrorKind::DuplicateEntry,
|
||||||
|
Box::from(err),
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl fmt::Display for SparsityPatternFormatError {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
|
match self {
|
||||||
|
SparsityPatternFormatError::InvalidOffsetArrayLength => {
|
||||||
|
write!(f, "Length of offset array is not equal to (major_dim + 1).")
|
||||||
|
}
|
||||||
|
SparsityPatternFormatError::InvalidOffsetFirstLast => {
|
||||||
|
write!(f, "First or last offset is incompatible with format.")
|
||||||
|
}
|
||||||
|
SparsityPatternFormatError::NonmonotonicOffsets => {
|
||||||
|
write!(f, "Offsets are not monotonically increasing.")
|
||||||
|
}
|
||||||
|
SparsityPatternFormatError::MinorIndexOutOfBounds => {
|
||||||
|
write!(f, "A minor index is out of bounds.")
|
||||||
|
}
|
||||||
|
SparsityPatternFormatError::DuplicateEntry => {
|
||||||
|
write!(f, "Input data contains duplicate entries.")
|
||||||
|
}
|
||||||
|
SparsityPatternFormatError::NonmonotonicMinorIndices => {
|
||||||
|
write!(
|
||||||
|
f,
|
||||||
|
"Minor indices are not monotonically increasing within each lane."
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Error for SparsityPatternFormatError {}
|
||||||
|
|
||||||
|
/// Iterator type for iterating over entries in a sparsity pattern.
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct SparsityPatternIter<'a> {
|
||||||
|
// See implementation of Iterator::next for an explanation of how these members are used
|
||||||
|
major_offsets: &'a [usize],
|
||||||
|
minor_indices: &'a [usize],
|
||||||
|
current_lane_idx: usize,
|
||||||
|
remaining_minors_in_lane: &'a [usize],
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> SparsityPatternIter<'a> {
|
||||||
|
fn from_pattern(pattern: &'a SparsityPattern) -> Self {
|
||||||
|
let first_lane_end = pattern.major_offsets().get(1).unwrap_or(&0);
|
||||||
|
let minors_in_first_lane = &pattern.minor_indices()[0..*first_lane_end];
|
||||||
|
Self {
|
||||||
|
major_offsets: pattern.major_offsets(),
|
||||||
|
minor_indices: pattern.minor_indices(),
|
||||||
|
current_lane_idx: 0,
|
||||||
|
remaining_minors_in_lane: minors_in_first_lane,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> Iterator for SparsityPatternIter<'a> {
|
||||||
|
type Item = (usize, usize);
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
fn next(&mut self) -> Option<Self::Item> {
|
||||||
|
// We ensure fast iteration across each lane by iteratively "draining" a slice
|
||||||
|
// corresponding to the remaining column indices in the particular lane.
|
||||||
|
// When we reach the end of this slice, we are at the end of a lane,
|
||||||
|
// and we must do some bookkeeping for preparing the iteration of the next lane
|
||||||
|
// (or stop iteration if we're through all lanes).
|
||||||
|
// This way we can avoid doing unnecessary bookkeeping on every iteration,
|
||||||
|
// instead paying a small price whenever we jump to a new lane.
|
||||||
|
if let Some(minor_idx) = self.remaining_minors_in_lane.first() {
|
||||||
|
let item = Some((self.current_lane_idx, *minor_idx));
|
||||||
|
self.remaining_minors_in_lane = &self.remaining_minors_in_lane[1..];
|
||||||
|
item
|
||||||
|
} else {
|
||||||
|
loop {
|
||||||
|
// Keep skipping lanes until we found a non-empty lane or there are no more lanes
|
||||||
|
if self.current_lane_idx + 2 >= self.major_offsets.len() {
|
||||||
|
// We've processed all lanes, so we're at the end of the iterator
|
||||||
|
// (note: keep in mind that offsets.len() == major_dim() + 1, hence we need +2)
|
||||||
|
return None;
|
||||||
|
} else {
|
||||||
|
// Bump lane index and check if the lane is non-empty
|
||||||
|
self.current_lane_idx += 1;
|
||||||
|
let lower = self.major_offsets[self.current_lane_idx];
|
||||||
|
let upper = self.major_offsets[self.current_lane_idx + 1];
|
||||||
|
if upper > lower {
|
||||||
|
self.remaining_minors_in_lane = &self.minor_indices[(lower + 1)..upper];
|
||||||
|
return Some((self.current_lane_idx, self.minor_indices[lower]));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
374
nalgebra-sparse/src/proptest.rs
Normal file
374
nalgebra-sparse/src/proptest.rs
Normal file
@ -0,0 +1,374 @@
|
|||||||
|
//! Functionality for integrating `nalgebra-sparse` with `proptest`.
|
||||||
|
//!
|
||||||
|
//! **This module is only available if the `proptest-support` feature is enabled**.
|
||||||
|
//!
|
||||||
|
//! The strategies provided here are generally expected to be able to generate the entire range
|
||||||
|
//! of possible outputs given the constraints on dimensions and values. However, there are no
|
||||||
|
//! particular guarantees on the distribution of possible values.
|
||||||
|
|
||||||
|
// Contains some patched code from proptest that we can remove in the (hopefully near) future.
|
||||||
|
// See docs in file for more details.
|
||||||
|
mod proptest_patched;
|
||||||
|
|
||||||
|
use crate::coo::CooMatrix;
|
||||||
|
use crate::csc::CscMatrix;
|
||||||
|
use crate::csr::CsrMatrix;
|
||||||
|
use crate::pattern::SparsityPattern;
|
||||||
|
use nalgebra::proptest::DimRange;
|
||||||
|
use nalgebra::{Dim, Scalar};
|
||||||
|
use proptest::collection::{btree_set, hash_map, vec};
|
||||||
|
use proptest::prelude::*;
|
||||||
|
use proptest::sample::Index;
|
||||||
|
use std::cmp::min;
|
||||||
|
use std::iter::repeat;
|
||||||
|
|
||||||
|
fn dense_row_major_coord_strategy(
|
||||||
|
nrows: usize,
|
||||||
|
ncols: usize,
|
||||||
|
nnz: usize,
|
||||||
|
) -> impl Strategy<Value = Vec<(usize, usize)>> {
|
||||||
|
assert!(nnz <= nrows * ncols);
|
||||||
|
let mut booleans = vec![true; nnz];
|
||||||
|
booleans.append(&mut vec![false; (nrows * ncols) - nnz]);
|
||||||
|
// Make sure that exactly `nnz` of the booleans are true
|
||||||
|
|
||||||
|
// TODO: We cannot use the below code because of a bug in proptest, see
|
||||||
|
// https://github.com/AltSysrq/proptest/pull/217
|
||||||
|
// so for now we're using a patched version of the Shuffle adapter
|
||||||
|
// (see also docs in `proptest_patched`
|
||||||
|
// Just(booleans)
|
||||||
|
// // Need to shuffle to make sure they are randomly distributed
|
||||||
|
// .prop_shuffle()
|
||||||
|
|
||||||
|
proptest_patched::Shuffle(Just(booleans)).prop_map(move |booleans| {
|
||||||
|
booleans
|
||||||
|
.into_iter()
|
||||||
|
.enumerate()
|
||||||
|
.filter_map(|(index, is_entry)| {
|
||||||
|
if is_entry {
|
||||||
|
// Convert linear index to row/col pair
|
||||||
|
let i = index / ncols;
|
||||||
|
let j = index % ncols;
|
||||||
|
Some((i, j))
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A strategy for generating `nnz` triplets.
|
||||||
|
///
|
||||||
|
/// This strategy should generally only be used when `nnz` is close to `nrows * ncols`.
|
||||||
|
fn dense_triplet_strategy<T>(
|
||||||
|
value_strategy: T,
|
||||||
|
nrows: usize,
|
||||||
|
ncols: usize,
|
||||||
|
nnz: usize,
|
||||||
|
) -> impl Strategy<Value = Vec<(usize, usize, T::Value)>>
|
||||||
|
where
|
||||||
|
T: Strategy + Clone + 'static,
|
||||||
|
T::Value: Scalar,
|
||||||
|
{
|
||||||
|
assert!(nnz <= nrows * ncols);
|
||||||
|
|
||||||
|
// Construct a number of booleans of which exactly `nnz` are true.
|
||||||
|
let booleans: Vec<_> = repeat(true)
|
||||||
|
.take(nnz)
|
||||||
|
.chain(repeat(false))
|
||||||
|
.take(nrows * ncols)
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
Just(booleans)
|
||||||
|
// Shuffle the booleans so that they are randomly distributed
|
||||||
|
.prop_shuffle()
|
||||||
|
// Convert the booleans into a list of coordinate pairs
|
||||||
|
.prop_map(move |booleans| {
|
||||||
|
booleans
|
||||||
|
.into_iter()
|
||||||
|
.enumerate()
|
||||||
|
.filter_map(|(index, is_entry)| {
|
||||||
|
if is_entry {
|
||||||
|
// Convert linear index to row/col pair
|
||||||
|
let i = index / ncols;
|
||||||
|
let j = index % ncols;
|
||||||
|
Some((i, j))
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
})
|
||||||
|
// Assign values to each coordinate pair in order to generate a list of triplets
|
||||||
|
.prop_flat_map(move |coords| {
|
||||||
|
vec![value_strategy.clone(); coords.len()].prop_map(move |values| {
|
||||||
|
coords
|
||||||
|
.clone()
|
||||||
|
.into_iter()
|
||||||
|
.zip(values)
|
||||||
|
.map(|((i, j), v)| (i, j, v))
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A strategy for generating `nnz` triplets.
|
||||||
|
///
|
||||||
|
/// This strategy should generally only be used when `nnz << nrows * ncols`. If `nnz` is too
|
||||||
|
/// close to `nrows * ncols` it may fail due to excessive rejected samples.
|
||||||
|
fn sparse_triplet_strategy<T>(
|
||||||
|
value_strategy: T,
|
||||||
|
nrows: usize,
|
||||||
|
ncols: usize,
|
||||||
|
nnz: usize,
|
||||||
|
) -> impl Strategy<Value = Vec<(usize, usize, T::Value)>>
|
||||||
|
where
|
||||||
|
T: Strategy + Clone + 'static,
|
||||||
|
T::Value: Scalar,
|
||||||
|
{
|
||||||
|
// Have to handle the zero case: proptest doesn't like empty ranges (i.e. 0 .. 0)
|
||||||
|
let row_index_strategy = if nrows > 0 { 0..nrows } else { 0..1 };
|
||||||
|
let col_index_strategy = if ncols > 0 { 0..ncols } else { 0..1 };
|
||||||
|
let coord_strategy = (row_index_strategy, col_index_strategy);
|
||||||
|
hash_map(coord_strategy, value_strategy.clone(), nnz)
|
||||||
|
.prop_map(|hash_map| {
|
||||||
|
let triplets: Vec<_> = hash_map.into_iter().map(|((i, j), v)| (i, j, v)).collect();
|
||||||
|
triplets
|
||||||
|
})
|
||||||
|
// Although order in the hash map is unspecified, it's not necessarily *random*
|
||||||
|
// - or, in particular, it does not necessarily sample the whole space of possible outcomes -
|
||||||
|
// so we additionally shuffle the triplets
|
||||||
|
.prop_shuffle()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A strategy for producing COO matrices without duplicate entries.
|
||||||
|
///
|
||||||
|
/// The values of the matrix are picked from the provided `value_strategy`, while the size of the
|
||||||
|
/// generated matrices is determined by the ranges `rows` and `cols`. The number of explicitly
|
||||||
|
/// stored entries is bounded from above by `max_nonzeros`. Note that the matrix might still
|
||||||
|
/// contain explicitly stored zeroes if the value strategy is capable of generating zero values.
|
||||||
|
pub fn coo_no_duplicates<T>(
|
||||||
|
value_strategy: T,
|
||||||
|
rows: impl Into<DimRange>,
|
||||||
|
cols: impl Into<DimRange>,
|
||||||
|
max_nonzeros: usize,
|
||||||
|
) -> impl Strategy<Value = CooMatrix<T::Value>>
|
||||||
|
where
|
||||||
|
T: Strategy + Clone + 'static,
|
||||||
|
T::Value: Scalar,
|
||||||
|
{
|
||||||
|
(
|
||||||
|
rows.into().to_range_inclusive(),
|
||||||
|
cols.into().to_range_inclusive(),
|
||||||
|
)
|
||||||
|
.prop_flat_map(move |(nrows, ncols)| {
|
||||||
|
let max_nonzeros = min(max_nonzeros, nrows * ncols);
|
||||||
|
let size_range = 0..=max_nonzeros;
|
||||||
|
let value_strategy = value_strategy.clone();
|
||||||
|
|
||||||
|
size_range
|
||||||
|
.prop_flat_map(move |nnz| {
|
||||||
|
let value_strategy = value_strategy.clone();
|
||||||
|
if nnz as f64 > 0.10 * (nrows as f64) * (ncols as f64) {
|
||||||
|
// If the number of nnz is sufficiently dense, then use the dense
|
||||||
|
// sample strategy
|
||||||
|
dense_triplet_strategy(value_strategy, nrows, ncols, nnz).boxed()
|
||||||
|
} else {
|
||||||
|
// Otherwise, use a hash map strategy so that we can get a sparse sampling
|
||||||
|
// (so that complexity is rather on the order of max_nnz than nrows * ncols)
|
||||||
|
sparse_triplet_strategy(value_strategy, nrows, ncols, nnz).boxed()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.prop_map(move |triplets| {
|
||||||
|
let mut coo = CooMatrix::new(nrows, ncols);
|
||||||
|
for (i, j, v) in triplets {
|
||||||
|
coo.push(i, j, v);
|
||||||
|
}
|
||||||
|
coo
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A strategy for producing COO matrices with duplicate entries.
|
||||||
|
///
|
||||||
|
/// The values of the matrix are picked from the provided `value_strategy`, while the size of the
|
||||||
|
/// generated matrices is determined by the ranges `rows` and `cols`. Note that the values
|
||||||
|
/// only apply to individual entries, and since this strategy can generate duplicate entries,
|
||||||
|
/// the matrix will generally have values outside the range determined by `value_strategy` when
|
||||||
|
/// converted to other formats, since the duplicate entries are summed together in this case.
|
||||||
|
///
|
||||||
|
/// The number of explicitly stored entries is bounded from above by `max_nonzeros`. The maximum
|
||||||
|
/// number of duplicate entries is determined by `max_duplicates`. Note that the matrix might still
|
||||||
|
/// contain explicitly stored zeroes if the value strategy is capable of generating zero values.
|
||||||
|
pub fn coo_with_duplicates<T>(
|
||||||
|
value_strategy: T,
|
||||||
|
rows: impl Into<DimRange>,
|
||||||
|
cols: impl Into<DimRange>,
|
||||||
|
max_nonzeros: usize,
|
||||||
|
max_duplicates: usize,
|
||||||
|
) -> impl Strategy<Value = CooMatrix<T::Value>>
|
||||||
|
where
|
||||||
|
T: Strategy + Clone + 'static,
|
||||||
|
T::Value: Scalar,
|
||||||
|
{
|
||||||
|
let coo_strategy = coo_no_duplicates(value_strategy.clone(), rows, cols, max_nonzeros);
|
||||||
|
let duplicate_strategy = vec((any::<Index>(), value_strategy.clone()), 0..=max_duplicates);
|
||||||
|
(coo_strategy, duplicate_strategy)
|
||||||
|
.prop_flat_map(|(coo, duplicates)| {
|
||||||
|
let mut triplets: Vec<(usize, usize, T::Value)> = coo
|
||||||
|
.triplet_iter()
|
||||||
|
.map(|(i, j, v)| (i, j, v.clone()))
|
||||||
|
.collect();
|
||||||
|
if !triplets.is_empty() {
|
||||||
|
let duplicates_iter: Vec<_> = duplicates
|
||||||
|
.into_iter()
|
||||||
|
.map(|(idx, val)| {
|
||||||
|
let (i, j, _) = idx.get(&triplets);
|
||||||
|
(*i, *j, val)
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
triplets.extend(duplicates_iter);
|
||||||
|
}
|
||||||
|
// Make sure to shuffle so that the duplicates get mixed in with the non-duplicates
|
||||||
|
let shuffled = Just(triplets).prop_shuffle();
|
||||||
|
(Just(coo.nrows()), Just(coo.ncols()), shuffled)
|
||||||
|
})
|
||||||
|
.prop_map(move |(nrows, ncols, triplets)| {
|
||||||
|
let mut coo = CooMatrix::new(nrows, ncols);
|
||||||
|
for (i, j, v) in triplets {
|
||||||
|
coo.push(i, j, v);
|
||||||
|
}
|
||||||
|
coo
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn sparsity_pattern_from_row_major_coords<I>(
|
||||||
|
nmajor: usize,
|
||||||
|
nminor: usize,
|
||||||
|
coords: I,
|
||||||
|
) -> SparsityPattern
|
||||||
|
where
|
||||||
|
I: Iterator<Item = (usize, usize)> + ExactSizeIterator,
|
||||||
|
{
|
||||||
|
let mut minors = Vec::with_capacity(coords.len());
|
||||||
|
let mut offsets = Vec::with_capacity(nmajor + 1);
|
||||||
|
let mut current_major = 0;
|
||||||
|
offsets.push(0);
|
||||||
|
for (idx, (i, j)) in coords.enumerate() {
|
||||||
|
assert!(i >= current_major);
|
||||||
|
assert!(
|
||||||
|
i < nmajor && j < nminor,
|
||||||
|
"Generated coords are out of bounds"
|
||||||
|
);
|
||||||
|
while current_major < i {
|
||||||
|
offsets.push(idx);
|
||||||
|
current_major += 1;
|
||||||
|
}
|
||||||
|
minors.push(j);
|
||||||
|
}
|
||||||
|
|
||||||
|
while current_major < nmajor {
|
||||||
|
offsets.push(minors.len());
|
||||||
|
current_major += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
assert_eq!(offsets.first().unwrap(), &0);
|
||||||
|
assert_eq!(offsets.len(), nmajor + 1);
|
||||||
|
|
||||||
|
SparsityPattern::try_from_offsets_and_indices(nmajor, nminor, offsets, minors)
|
||||||
|
.expect("Internal error: Generated sparsity pattern is invalid")
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A strategy for generating sparsity patterns.
|
||||||
|
pub fn sparsity_pattern(
|
||||||
|
major_lanes: impl Into<DimRange>,
|
||||||
|
minor_lanes: impl Into<DimRange>,
|
||||||
|
max_nonzeros: usize,
|
||||||
|
) -> impl Strategy<Value = SparsityPattern> {
|
||||||
|
(
|
||||||
|
major_lanes.into().to_range_inclusive(),
|
||||||
|
minor_lanes.into().to_range_inclusive(),
|
||||||
|
)
|
||||||
|
.prop_flat_map(move |(nmajor, nminor)| {
|
||||||
|
let max_nonzeros = min(nmajor * nminor, max_nonzeros);
|
||||||
|
(Just(nmajor), Just(nminor), 0..=max_nonzeros)
|
||||||
|
})
|
||||||
|
.prop_flat_map(move |(nmajor, nminor, nnz)| {
|
||||||
|
if 10 * nnz < nmajor * nminor {
|
||||||
|
// If nnz is small compared to a dense matrix, then use a sparse sampling strategy
|
||||||
|
btree_set((0..nmajor, 0..nminor), nnz)
|
||||||
|
.prop_map(move |coords| {
|
||||||
|
sparsity_pattern_from_row_major_coords(nmajor, nminor, coords.into_iter())
|
||||||
|
})
|
||||||
|
.boxed()
|
||||||
|
} else {
|
||||||
|
// If the required number of nonzeros is sufficiently dense,
|
||||||
|
// we instead use a dense sampling
|
||||||
|
dense_row_major_coord_strategy(nmajor, nminor, nnz)
|
||||||
|
.prop_map(move |coords| {
|
||||||
|
let coords = coords.into_iter();
|
||||||
|
sparsity_pattern_from_row_major_coords(nmajor, nminor, coords)
|
||||||
|
})
|
||||||
|
.boxed()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A strategy for generating CSR matrices.
|
||||||
|
pub fn csr<T>(
|
||||||
|
value_strategy: T,
|
||||||
|
rows: impl Into<DimRange>,
|
||||||
|
cols: impl Into<DimRange>,
|
||||||
|
max_nonzeros: usize,
|
||||||
|
) -> impl Strategy<Value = CsrMatrix<T::Value>>
|
||||||
|
where
|
||||||
|
T: Strategy + Clone + 'static,
|
||||||
|
T::Value: Scalar,
|
||||||
|
{
|
||||||
|
let rows = rows.into();
|
||||||
|
let cols = cols.into();
|
||||||
|
sparsity_pattern(
|
||||||
|
rows.lower_bound().value()..=rows.upper_bound().value(),
|
||||||
|
cols.lower_bound().value()..=cols.upper_bound().value(),
|
||||||
|
max_nonzeros,
|
||||||
|
)
|
||||||
|
.prop_flat_map(move |pattern| {
|
||||||
|
let nnz = pattern.nnz();
|
||||||
|
let values = vec![value_strategy.clone(); nnz];
|
||||||
|
(Just(pattern), values)
|
||||||
|
})
|
||||||
|
.prop_map(|(pattern, values)| {
|
||||||
|
CsrMatrix::try_from_pattern_and_values(pattern, values)
|
||||||
|
.expect("Internal error: Generated CsrMatrix is invalid")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A strategy for generating CSC matrices.
|
||||||
|
pub fn csc<T>(
|
||||||
|
value_strategy: T,
|
||||||
|
rows: impl Into<DimRange>,
|
||||||
|
cols: impl Into<DimRange>,
|
||||||
|
max_nonzeros: usize,
|
||||||
|
) -> impl Strategy<Value = CscMatrix<T::Value>>
|
||||||
|
where
|
||||||
|
T: Strategy + Clone + 'static,
|
||||||
|
T::Value: Scalar,
|
||||||
|
{
|
||||||
|
let rows = rows.into();
|
||||||
|
let cols = cols.into();
|
||||||
|
sparsity_pattern(
|
||||||
|
cols.lower_bound().value()..=cols.upper_bound().value(),
|
||||||
|
rows.lower_bound().value()..=rows.upper_bound().value(),
|
||||||
|
max_nonzeros,
|
||||||
|
)
|
||||||
|
.prop_flat_map(move |pattern| {
|
||||||
|
let nnz = pattern.nnz();
|
||||||
|
let values = vec![value_strategy.clone(); nnz];
|
||||||
|
(Just(pattern), values)
|
||||||
|
})
|
||||||
|
.prop_map(|(pattern, values)| {
|
||||||
|
CscMatrix::try_from_pattern_and_values(pattern, values)
|
||||||
|
.expect("Internal error: Generated CscMatrix is invalid")
|
||||||
|
})
|
||||||
|
}
|
146
nalgebra-sparse/src/proptest/proptest_patched.rs
Normal file
146
nalgebra-sparse/src/proptest/proptest_patched.rs
Normal file
@ -0,0 +1,146 @@
|
|||||||
|
//! Contains a modified implementation of `proptest::strategy::Shuffle`.
|
||||||
|
//!
|
||||||
|
//! The current implementation in `proptest` does not generate all permutations, which is
|
||||||
|
//! problematic for our proptest generators. The issue has been fixed in
|
||||||
|
//! https://github.com/AltSysrq/proptest/pull/217
|
||||||
|
//! but it has yet to be merged and released. As soon as this fix makes it into a new release,
|
||||||
|
//! the modified code here can be removed.
|
||||||
|
//!
|
||||||
|
/*!
|
||||||
|
This code has been copied and adapted from
|
||||||
|
https://github.com/AltSysrq/proptest/blob/master/proptest/src/strategy/shuffle.rs
|
||||||
|
The original licensing text is:
|
||||||
|
|
||||||
|
//-
|
||||||
|
// Copyright 2017 Jason Lingle
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
|
||||||
|
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
|
||||||
|
// option. This file may not be copied, modified, or distributed
|
||||||
|
// except according to those terms.
|
||||||
|
|
||||||
|
*/
|
||||||
|
|
||||||
|
use proptest::num;
|
||||||
|
use proptest::prelude::Rng;
|
||||||
|
use proptest::strategy::{NewTree, Shuffleable, Strategy, ValueTree};
|
||||||
|
use proptest::test_runner::{TestRng, TestRunner};
|
||||||
|
use std::cell::Cell;
|
||||||
|
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
#[must_use = "strategies do nothing unless used"]
|
||||||
|
pub struct Shuffle<S>(pub(super) S);
|
||||||
|
|
||||||
|
impl<S: Strategy> Strategy for Shuffle<S>
|
||||||
|
where
|
||||||
|
S::Value: Shuffleable,
|
||||||
|
{
|
||||||
|
type Tree = ShuffleValueTree<S::Tree>;
|
||||||
|
type Value = S::Value;
|
||||||
|
|
||||||
|
fn new_tree(&self, runner: &mut TestRunner) -> NewTree<Self> {
|
||||||
|
let rng = runner.new_rng();
|
||||||
|
|
||||||
|
self.0.new_tree(runner).map(|inner| ShuffleValueTree {
|
||||||
|
inner,
|
||||||
|
rng,
|
||||||
|
dist: Cell::new(None),
|
||||||
|
simplifying_inner: false,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
pub struct ShuffleValueTree<V> {
|
||||||
|
inner: V,
|
||||||
|
rng: TestRng,
|
||||||
|
dist: Cell<Option<num::usize::BinarySearch>>,
|
||||||
|
simplifying_inner: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<V: ValueTree> ShuffleValueTree<V>
|
||||||
|
where
|
||||||
|
V::Value: Shuffleable,
|
||||||
|
{
|
||||||
|
fn init_dist(&self, dflt: usize) -> usize {
|
||||||
|
if self.dist.get().is_none() {
|
||||||
|
self.dist.set(Some(num::usize::BinarySearch::new(dflt)));
|
||||||
|
}
|
||||||
|
|
||||||
|
self.dist.get().unwrap().current()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn force_init_dist(&self) {
|
||||||
|
if self.dist.get().is_none() {
|
||||||
|
let _ = self.init_dist(self.current().shuffle_len());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<V: ValueTree> ValueTree for ShuffleValueTree<V>
|
||||||
|
where
|
||||||
|
V::Value: Shuffleable,
|
||||||
|
{
|
||||||
|
type Value = V::Value;
|
||||||
|
|
||||||
|
fn current(&self) -> V::Value {
|
||||||
|
let mut value = self.inner.current();
|
||||||
|
let len = value.shuffle_len();
|
||||||
|
// The maximum distance to swap elements. This could be larger than
|
||||||
|
// `value` if `value` has reduced size during shrinking; that's OK,
|
||||||
|
// since we only use this to filter swaps.
|
||||||
|
let max_swap = self.init_dist(len);
|
||||||
|
|
||||||
|
// If empty collection or all swaps will be filtered out, there's
|
||||||
|
// nothing to shuffle.
|
||||||
|
if 0 == len || 0 == max_swap {
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut rng = self.rng.clone();
|
||||||
|
|
||||||
|
for start_index in 0..len - 1 {
|
||||||
|
// Determine the other index to be swapped, then skip the swap if
|
||||||
|
// it is too far. This ordering is critical, as it ensures that we
|
||||||
|
// generate the same sequence of random numbers every time.
|
||||||
|
|
||||||
|
// NOTE: The below line is the whole reason for the existence of this adapted code
|
||||||
|
// We need to be able to swap with the same element, so that some elements remain in
|
||||||
|
// place rather being swapped
|
||||||
|
// let end_index = rng.gen_range(start_index + 1, len);
|
||||||
|
let end_index = rng.gen_range(start_index, len);
|
||||||
|
if end_index - start_index <= max_swap {
|
||||||
|
value.shuffle_swap(start_index, end_index);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
value
|
||||||
|
}
|
||||||
|
|
||||||
|
fn simplify(&mut self) -> bool {
|
||||||
|
if self.simplifying_inner {
|
||||||
|
self.inner.simplify()
|
||||||
|
} else {
|
||||||
|
// Ensure that we've initialised `dist` to *something* to give
|
||||||
|
// consistent non-panicking behaviour even if called in an
|
||||||
|
// unexpected sequence.
|
||||||
|
self.force_init_dist();
|
||||||
|
if self.dist.get_mut().as_mut().unwrap().simplify() {
|
||||||
|
true
|
||||||
|
} else {
|
||||||
|
self.simplifying_inner = true;
|
||||||
|
self.inner.simplify()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn complicate(&mut self) -> bool {
|
||||||
|
if self.simplifying_inner {
|
||||||
|
self.inner.complicate()
|
||||||
|
} else {
|
||||||
|
self.force_init_dist();
|
||||||
|
self.dist.get_mut().as_mut().unwrap().complicate()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
77
nalgebra-sparse/tests/common/mod.rs
Normal file
77
nalgebra-sparse/tests/common/mod.rs
Normal file
@ -0,0 +1,77 @@
|
|||||||
|
use nalgebra_sparse::csc::CscMatrix;
|
||||||
|
use nalgebra_sparse::csr::CsrMatrix;
|
||||||
|
use nalgebra_sparse::proptest::{csc, csr};
|
||||||
|
use proptest::strategy::Strategy;
|
||||||
|
use std::convert::TryFrom;
|
||||||
|
use std::fmt::Debug;
|
||||||
|
use std::ops::RangeInclusive;
|
||||||
|
|
||||||
|
#[macro_export]
|
||||||
|
macro_rules! assert_panics {
|
||||||
|
($e:expr) => {{
|
||||||
|
use std::panic::catch_unwind;
|
||||||
|
use std::stringify;
|
||||||
|
let expr_string = stringify!($e);
|
||||||
|
|
||||||
|
// Note: We cannot manipulate the panic hook here, because it is global and the test
|
||||||
|
// suite is run in parallel, which leads to race conditions in the sense
|
||||||
|
// that some regular tests that panic might not output anything anymore.
|
||||||
|
// Unfortunately this means that output is still printed to stdout if
|
||||||
|
// we run cargo test -- --nocapture. But Cargo does not forward this if the test
|
||||||
|
// binary is not run with nocapture, so it is somewhat acceptable nonetheless.
|
||||||
|
|
||||||
|
let result = catch_unwind(|| $e);
|
||||||
|
if result.is_ok() {
|
||||||
|
panic!(
|
||||||
|
"assert_panics!({}) failed: the expression did not panic.",
|
||||||
|
expr_string
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}};
|
||||||
|
}
|
||||||
|
|
||||||
|
pub const PROPTEST_MATRIX_DIM: RangeInclusive<usize> = 0..=6;
|
||||||
|
pub const PROPTEST_MAX_NNZ: usize = 40;
|
||||||
|
pub const PROPTEST_I32_VALUE_STRATEGY: RangeInclusive<i32> = -5..=5;
|
||||||
|
|
||||||
|
pub fn value_strategy<T>() -> RangeInclusive<T>
|
||||||
|
where
|
||||||
|
T: TryFrom<i32>,
|
||||||
|
T::Error: Debug,
|
||||||
|
{
|
||||||
|
let (start, end) = (
|
||||||
|
PROPTEST_I32_VALUE_STRATEGY.start(),
|
||||||
|
PROPTEST_I32_VALUE_STRATEGY.end(),
|
||||||
|
);
|
||||||
|
T::try_from(*start).unwrap()..=T::try_from(*end).unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn non_zero_i32_value_strategy() -> impl Strategy<Value = i32> {
|
||||||
|
let (start, end) = (
|
||||||
|
PROPTEST_I32_VALUE_STRATEGY.start(),
|
||||||
|
PROPTEST_I32_VALUE_STRATEGY.end(),
|
||||||
|
);
|
||||||
|
assert!(start < &0);
|
||||||
|
assert!(end > &0);
|
||||||
|
// Note: we don't use RangeInclusive for the second range, because then we'd have different
|
||||||
|
// types, which would require boxing
|
||||||
|
(*start..0).prop_union(1..*end + 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn csr_strategy() -> impl Strategy<Value = CsrMatrix<i32>> {
|
||||||
|
csr(
|
||||||
|
PROPTEST_I32_VALUE_STRATEGY,
|
||||||
|
PROPTEST_MATRIX_DIM,
|
||||||
|
PROPTEST_MATRIX_DIM,
|
||||||
|
PROPTEST_MAX_NNZ,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn csc_strategy() -> impl Strategy<Value = CscMatrix<i32>> {
|
||||||
|
csc(
|
||||||
|
PROPTEST_I32_VALUE_STRATEGY,
|
||||||
|
PROPTEST_MATRIX_DIM,
|
||||||
|
PROPTEST_MATRIX_DIM,
|
||||||
|
PROPTEST_MAX_NNZ,
|
||||||
|
)
|
||||||
|
}
|
8
nalgebra-sparse/tests/unit.rs
Normal file
8
nalgebra-sparse/tests/unit.rs
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
//! Unit tests
|
||||||
|
#[cfg(any(not(feature = "proptest-support"), not(feature = "compare")))]
|
||||||
|
compile_error!("Tests must be run with features `proptest-support` and `compare`");
|
||||||
|
|
||||||
|
mod unit_tests;
|
||||||
|
|
||||||
|
#[macro_use]
|
||||||
|
pub mod common;
|
@ -0,0 +1,8 @@
|
|||||||
|
# Seeds for failure cases proptest has generated in the past. It is
|
||||||
|
# automatically read and these particular cases re-run before any
|
||||||
|
# novel cases are generated.
|
||||||
|
#
|
||||||
|
# It is recommended to check this file in to source control so that
|
||||||
|
# everyone who runs the test benefits from these saved cases.
|
||||||
|
cc 3f71c8edc555965e521e3aaf58c736240a0e333c3a9d54e8a836d7768c371215 # shrinks to matrix = CscMatrix { cs: CsMatrix { sparsity_pattern: SparsityPattern { major_offsets: [0], minor_indices: [], minor_dim: 0 }, values: [] } }
|
||||||
|
cc aef645e3184b814ef39fbb10234f12e6ff502ab515dabefafeedab5895e22b12 # shrinks to (matrix, rhs) = (CscMatrix { cs: CsMatrix { sparsity_pattern: SparsityPattern { major_offsets: [0, 4, 7, 11, 14], minor_indices: [0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 0, 2, 3], minor_dim: 4 }, values: [1.0, 0.0, 0.0, 0.0, 0.0, 40.90124126326177, 36.975170911665906, 0.0, 36.975170911665906, 42.51062858727923, -12.984115201530539, 0.0, -12.984115201530539, 27.73953543265418] } }, Matrix { data: VecStorage { data: [0.0, 0.0, 0.0, -4.05763092330143], nrows: Dynamic { value: 4 }, ncols: Dynamic { value: 1 } } })
|
117
nalgebra-sparse/tests/unit_tests/cholesky.rs
Normal file
117
nalgebra-sparse/tests/unit_tests/cholesky.rs
Normal file
@ -0,0 +1,117 @@
|
|||||||
|
#![cfg_attr(rustfmt, rustfmt_skip)]
|
||||||
|
use crate::common::{value_strategy, PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ};
|
||||||
|
use nalgebra_sparse::csc::CscMatrix;
|
||||||
|
use nalgebra_sparse::factorization::{CscCholesky};
|
||||||
|
use nalgebra_sparse::proptest::csc;
|
||||||
|
use nalgebra::{Matrix5, Vector5, Cholesky, DMatrix};
|
||||||
|
use nalgebra::proptest::matrix;
|
||||||
|
|
||||||
|
use proptest::prelude::*;
|
||||||
|
use matrixcompare::{assert_matrix_eq, prop_assert_matrix_eq};
|
||||||
|
|
||||||
|
fn positive_definite() -> impl Strategy<Value=CscMatrix<f64>> {
|
||||||
|
let csc_f64 = csc(value_strategy::<f64>(),
|
||||||
|
PROPTEST_MATRIX_DIM,
|
||||||
|
PROPTEST_MATRIX_DIM,
|
||||||
|
PROPTEST_MAX_NNZ);
|
||||||
|
csc_f64
|
||||||
|
.prop_map(|x| {
|
||||||
|
// Add a small multiple of the identity to ensure positive definiteness
|
||||||
|
x.transpose() * &x + CscMatrix::identity(x.ncols())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
proptest! {
|
||||||
|
#[test]
|
||||||
|
fn cholesky_correct_for_positive_definite_matrices(
|
||||||
|
matrix in positive_definite()
|
||||||
|
) {
|
||||||
|
let cholesky = CscCholesky::factor(&matrix).unwrap();
|
||||||
|
let l = cholesky.take_l();
|
||||||
|
let matrix_reconstructed = &l * l.transpose();
|
||||||
|
|
||||||
|
prop_assert_matrix_eq!(matrix_reconstructed, matrix, comp = abs, tol = 1e-8);
|
||||||
|
|
||||||
|
let is_lower_triangular = l.triplet_iter().all(|(i, j, _)| j <= i);
|
||||||
|
prop_assert!(is_lower_triangular);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn cholesky_solve_positive_definite(
|
||||||
|
(matrix, rhs) in positive_definite()
|
||||||
|
.prop_flat_map(|csc| {
|
||||||
|
let rhs = matrix(value_strategy::<f64>(), csc.nrows(), PROPTEST_MATRIX_DIM);
|
||||||
|
(Just(csc), rhs)
|
||||||
|
})
|
||||||
|
) {
|
||||||
|
let cholesky = CscCholesky::factor(&matrix).unwrap();
|
||||||
|
|
||||||
|
// solve_mut
|
||||||
|
{
|
||||||
|
let mut x = rhs.clone();
|
||||||
|
cholesky.solve_mut(&mut x);
|
||||||
|
prop_assert_matrix_eq!(&matrix * &x, rhs, comp=abs, tol=1e-12);
|
||||||
|
}
|
||||||
|
|
||||||
|
// solve
|
||||||
|
{
|
||||||
|
let x = cholesky.solve(&rhs);
|
||||||
|
prop_assert_matrix_eq!(&matrix * &x, rhs, comp=abs, tol=1e-12);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// This is a test ported from nalgebra's "sparse" module, for the original CsCholesky impl
|
||||||
|
#[test]
|
||||||
|
fn cs_cholesky() {
|
||||||
|
let mut a = Matrix5::new(
|
||||||
|
40.0, 0.0, 0.0, 0.0, 0.0,
|
||||||
|
2.0, 60.0, 0.0, 0.0, 0.0,
|
||||||
|
1.0, 0.0, 11.0, 0.0, 0.0,
|
||||||
|
0.0, 0.0, 0.0, 50.0, 0.0,
|
||||||
|
1.0, 0.0, 0.0, 4.0, 10.0
|
||||||
|
);
|
||||||
|
a.fill_upper_triangle_with_lower_triangle();
|
||||||
|
test_cholesky(a);
|
||||||
|
|
||||||
|
let a = Matrix5::from_diagonal(&Vector5::new(40.0, 60.0, 11.0, 50.0, 10.0));
|
||||||
|
test_cholesky(a);
|
||||||
|
|
||||||
|
let mut a = Matrix5::new(
|
||||||
|
40.0, 0.0, 0.0, 0.0, 0.0,
|
||||||
|
2.0, 60.0, 0.0, 0.0, 0.0,
|
||||||
|
1.0, 0.0, 11.0, 0.0, 0.0,
|
||||||
|
1.0, 0.0, 0.0, 50.0, 0.0,
|
||||||
|
0.0, 0.0, 0.0, 4.0, 10.0
|
||||||
|
);
|
||||||
|
a.fill_upper_triangle_with_lower_triangle();
|
||||||
|
test_cholesky(a);
|
||||||
|
|
||||||
|
let mut a = Matrix5::new(
|
||||||
|
2.0, 0.0, 0.0, 0.0, 0.0,
|
||||||
|
0.0, 2.0, 0.0, 0.0, 0.0,
|
||||||
|
1.0, 1.0, 2.0, 0.0, 0.0,
|
||||||
|
0.0, 0.0, 0.0, 2.0, 0.0,
|
||||||
|
1.0, 1.0, 0.0, 0.0, 2.0
|
||||||
|
);
|
||||||
|
a.fill_upper_triangle_with_lower_triangle();
|
||||||
|
// Test crate::new, left_looking, and up_looking implementations.
|
||||||
|
test_cholesky(a);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn test_cholesky(a: Matrix5<f64>) {
|
||||||
|
// TODO: Test "refactor"
|
||||||
|
|
||||||
|
let cs_a = CscMatrix::from(&a);
|
||||||
|
|
||||||
|
let chol_a = Cholesky::new(a).unwrap();
|
||||||
|
let chol_cs_a = CscCholesky::factor(&cs_a).unwrap();
|
||||||
|
|
||||||
|
let l = chol_a.l();
|
||||||
|
let cs_l = chol_cs_a.take_l();
|
||||||
|
|
||||||
|
let l = DMatrix::from_iterator(l.nrows(), l.ncols(), l.iter().cloned());
|
||||||
|
let cs_l_mat = DMatrix::from(&cs_l);
|
||||||
|
assert_matrix_eq!(l, cs_l_mat, comp = abs, tol = 1e-12);
|
||||||
|
}
|
@ -0,0 +1,10 @@
|
|||||||
|
# Seeds for failure cases proptest has generated in the past. It is
|
||||||
|
# automatically read and these particular cases re-run before any
|
||||||
|
# novel cases are generated.
|
||||||
|
#
|
||||||
|
# It is recommended to check this file in to source control so that
|
||||||
|
# everyone who runs the test benefits from these saved cases.
|
||||||
|
cc 07cb95127d2700ff2000157938e351ce2b43f3e6419d69b00726abfc03e682bd # shrinks to coo = CooMatrix { nrows: 4, ncols: 5, row_indices: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0], col_indices: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 0, 4, 3], values: [1, -5, -4, -5, 1, 2, 4, -4, -4, -5, 2, -2, 4, -4] }
|
||||||
|
cc 8fdaf70d6091d89a6617573547745e9802bb9c1ce7c6ec7ad4f301cd05d54c5d # shrinks to dense = Matrix { data: VecStorage { data: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1], nrows: Dynamic { value: 4 }, ncols: Dynamic { value: 5 } } }
|
||||||
|
cc 6961760ac7915b57a28230524cea7e9bfcea4f31790e3c0569ea74af904c2d79 # shrinks to coo = CooMatrix { nrows: 6, ncols: 6, row_indices: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0], col_indices: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0], values: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0] }
|
||||||
|
cc c9a1af218f7a974f1fda7b8909c2635d735eedbfe953082ef6b0b92702bf6d1b # shrinks to dense = Matrix { data: VecStorage { data: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1], nrows: Dynamic { value: 6 }, ncols: Dynamic { value: 5 } } }
|
452
nalgebra-sparse/tests/unit_tests/convert_serial.rs
Normal file
452
nalgebra-sparse/tests/unit_tests/convert_serial.rs
Normal file
@ -0,0 +1,452 @@
|
|||||||
|
use crate::common::csc_strategy;
|
||||||
|
use nalgebra::proptest::matrix;
|
||||||
|
use nalgebra::DMatrix;
|
||||||
|
use nalgebra_sparse::convert::serial::{
|
||||||
|
convert_coo_csc, convert_coo_csr, convert_coo_dense, convert_csc_coo, convert_csc_csr,
|
||||||
|
convert_csc_dense, convert_csr_coo, convert_csr_csc, convert_csr_dense, convert_dense_coo,
|
||||||
|
convert_dense_csc, convert_dense_csr,
|
||||||
|
};
|
||||||
|
use nalgebra_sparse::coo::CooMatrix;
|
||||||
|
use nalgebra_sparse::csc::CscMatrix;
|
||||||
|
use nalgebra_sparse::csr::CsrMatrix;
|
||||||
|
use nalgebra_sparse::proptest::{coo_no_duplicates, coo_with_duplicates, csc, csr};
|
||||||
|
use proptest::prelude::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_convert_dense_coo() {
|
||||||
|
// No duplicates
|
||||||
|
{
|
||||||
|
#[rustfmt::skip]
|
||||||
|
let entries = &[1, 0, 3,
|
||||||
|
0, 5, 0];
|
||||||
|
// The COO representation of a dense matrix is not unique.
|
||||||
|
// Here we implicitly test that the coo matrix is indeed constructed from column-major
|
||||||
|
// iteration of the dense matrix.
|
||||||
|
let dense = DMatrix::from_row_slice(2, 3, entries);
|
||||||
|
let coo = CooMatrix::try_from_triplets(2, 3, vec![0, 1, 0], vec![0, 1, 2], vec![1, 5, 3])
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert_eq!(CooMatrix::from(&dense), coo);
|
||||||
|
assert_eq!(DMatrix::from(&coo), dense);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Duplicates
|
||||||
|
// No duplicates
|
||||||
|
{
|
||||||
|
#[rustfmt::skip]
|
||||||
|
let entries = &[1, 0, 3,
|
||||||
|
0, 5, 0];
|
||||||
|
// The COO representation of a dense matrix is not unique.
|
||||||
|
// Here we implicitly test that the coo matrix is indeed constructed from column-major
|
||||||
|
// iteration of the dense matrix.
|
||||||
|
let dense = DMatrix::from_row_slice(2, 3, entries);
|
||||||
|
let coo_no_dup =
|
||||||
|
CooMatrix::try_from_triplets(2, 3, vec![0, 1, 0], vec![0, 1, 2], vec![1, 5, 3])
|
||||||
|
.unwrap();
|
||||||
|
let coo_dup = CooMatrix::try_from_triplets(
|
||||||
|
2,
|
||||||
|
3,
|
||||||
|
vec![0, 1, 0, 1],
|
||||||
|
vec![0, 1, 2, 1],
|
||||||
|
vec![1, -2, 3, 7],
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert_eq!(CooMatrix::from(&dense), coo_no_dup);
|
||||||
|
assert_eq!(DMatrix::from(&coo_dup), dense);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_convert_coo_csr() {
|
||||||
|
// No duplicates
|
||||||
|
{
|
||||||
|
let coo = {
|
||||||
|
let mut coo = CooMatrix::new(3, 4);
|
||||||
|
coo.push(1, 3, 4);
|
||||||
|
coo.push(0, 1, 2);
|
||||||
|
coo.push(2, 0, 1);
|
||||||
|
coo.push(2, 3, 2);
|
||||||
|
coo.push(2, 2, 1);
|
||||||
|
coo
|
||||||
|
};
|
||||||
|
|
||||||
|
let expected_csr = CsrMatrix::try_from_csr_data(
|
||||||
|
3,
|
||||||
|
4,
|
||||||
|
vec![0, 1, 2, 5],
|
||||||
|
vec![1, 3, 0, 2, 3],
|
||||||
|
vec![2, 4, 1, 1, 2],
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert_eq!(convert_coo_csr(&coo), expected_csr);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Duplicates
|
||||||
|
{
|
||||||
|
let coo = {
|
||||||
|
let mut coo = CooMatrix::new(3, 4);
|
||||||
|
coo.push(1, 3, 4);
|
||||||
|
coo.push(2, 3, 2);
|
||||||
|
coo.push(0, 1, 2);
|
||||||
|
coo.push(2, 0, 1);
|
||||||
|
coo.push(2, 3, 2);
|
||||||
|
coo.push(0, 1, 3);
|
||||||
|
coo.push(2, 2, 1);
|
||||||
|
coo
|
||||||
|
};
|
||||||
|
|
||||||
|
let expected_csr = CsrMatrix::try_from_csr_data(
|
||||||
|
3,
|
||||||
|
4,
|
||||||
|
vec![0, 1, 2, 5],
|
||||||
|
vec![1, 3, 0, 2, 3],
|
||||||
|
vec![5, 4, 1, 1, 4],
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert_eq!(convert_coo_csr(&coo), expected_csr);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_convert_csr_coo() {
|
||||||
|
let csr = CsrMatrix::try_from_csr_data(
|
||||||
|
3,
|
||||||
|
4,
|
||||||
|
vec![0, 1, 2, 5],
|
||||||
|
vec![1, 3, 0, 2, 3],
|
||||||
|
vec![5, 4, 1, 1, 4],
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let expected_coo = CooMatrix::try_from_triplets(
|
||||||
|
3,
|
||||||
|
4,
|
||||||
|
vec![0, 1, 2, 2, 2],
|
||||||
|
vec![1, 3, 0, 2, 3],
|
||||||
|
vec![5, 4, 1, 1, 4],
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert_eq!(convert_csr_coo(&csr), expected_coo);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_convert_coo_csc() {
|
||||||
|
// No duplicates
|
||||||
|
{
|
||||||
|
let coo = {
|
||||||
|
let mut coo = CooMatrix::new(3, 4);
|
||||||
|
coo.push(1, 3, 4);
|
||||||
|
coo.push(0, 1, 2);
|
||||||
|
coo.push(2, 0, 1);
|
||||||
|
coo.push(2, 3, 2);
|
||||||
|
coo.push(2, 2, 1);
|
||||||
|
coo
|
||||||
|
};
|
||||||
|
|
||||||
|
let expected_csc = CscMatrix::try_from_csc_data(
|
||||||
|
3,
|
||||||
|
4,
|
||||||
|
vec![0, 1, 2, 3, 5],
|
||||||
|
vec![2, 0, 2, 1, 2],
|
||||||
|
vec![1, 2, 1, 4, 2],
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert_eq!(convert_coo_csc(&coo), expected_csc);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Duplicates
|
||||||
|
{
|
||||||
|
let coo = {
|
||||||
|
let mut coo = CooMatrix::new(3, 4);
|
||||||
|
coo.push(1, 3, 4);
|
||||||
|
coo.push(2, 3, 2);
|
||||||
|
coo.push(0, 1, 2);
|
||||||
|
coo.push(2, 0, 1);
|
||||||
|
coo.push(2, 3, 2);
|
||||||
|
coo.push(0, 1, 3);
|
||||||
|
coo.push(2, 2, 1);
|
||||||
|
coo
|
||||||
|
};
|
||||||
|
|
||||||
|
let expected_csc = CscMatrix::try_from_csc_data(
|
||||||
|
3,
|
||||||
|
4,
|
||||||
|
vec![0, 1, 2, 3, 5],
|
||||||
|
vec![2, 0, 2, 1, 2],
|
||||||
|
vec![1, 5, 1, 4, 4],
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert_eq!(convert_coo_csc(&coo), expected_csc);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_convert_csc_coo() {
|
||||||
|
let csc = CscMatrix::try_from_csc_data(
|
||||||
|
3,
|
||||||
|
4,
|
||||||
|
vec![0, 1, 2, 3, 5],
|
||||||
|
vec![2, 0, 2, 1, 2],
|
||||||
|
vec![1, 2, 1, 4, 2],
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let expected_coo = CooMatrix::try_from_triplets(
|
||||||
|
3,
|
||||||
|
4,
|
||||||
|
vec![2, 0, 2, 1, 2],
|
||||||
|
vec![0, 1, 2, 3, 3],
|
||||||
|
vec![1, 2, 1, 4, 2],
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert_eq!(convert_csc_coo(&csc), expected_coo);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_convert_csr_csc_bidirectional() {
|
||||||
|
let csr = CsrMatrix::try_from_csr_data(
|
||||||
|
3,
|
||||||
|
4,
|
||||||
|
vec![0, 3, 4, 6],
|
||||||
|
vec![1, 2, 3, 0, 1, 3],
|
||||||
|
vec![5, 3, 2, 2, 1, 4],
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let csc = CscMatrix::try_from_csc_data(
|
||||||
|
3,
|
||||||
|
4,
|
||||||
|
vec![0, 1, 3, 4, 6],
|
||||||
|
vec![1, 0, 2, 0, 0, 2],
|
||||||
|
vec![2, 5, 1, 3, 2, 4],
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert_eq!(convert_csr_csc(&csr), csc);
|
||||||
|
assert_eq!(convert_csc_csr(&csc), csr);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_convert_csr_dense_bidirectional() {
|
||||||
|
let csr = CsrMatrix::try_from_csr_data(
|
||||||
|
3,
|
||||||
|
4,
|
||||||
|
vec![0, 3, 4, 6],
|
||||||
|
vec![1, 2, 3, 0, 1, 3],
|
||||||
|
vec![5, 3, 2, 2, 1, 4],
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
#[rustfmt::skip]
|
||||||
|
let dense = DMatrix::from_row_slice(3, 4, &[
|
||||||
|
0, 5, 3, 2,
|
||||||
|
2, 0, 0, 0,
|
||||||
|
0, 1, 0, 4
|
||||||
|
]);
|
||||||
|
|
||||||
|
assert_eq!(convert_csr_dense(&csr), dense);
|
||||||
|
assert_eq!(convert_dense_csr(&dense), csr);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_convert_csc_dense_bidirectional() {
|
||||||
|
let csc = CscMatrix::try_from_csc_data(
|
||||||
|
3,
|
||||||
|
4,
|
||||||
|
vec![0, 1, 3, 4, 6],
|
||||||
|
vec![1, 0, 2, 0, 0, 2],
|
||||||
|
vec![2, 5, 1, 3, 2, 4],
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
#[rustfmt::skip]
|
||||||
|
let dense = DMatrix::from_row_slice(3, 4, &[
|
||||||
|
0, 5, 3, 2,
|
||||||
|
2, 0, 0, 0,
|
||||||
|
0, 1, 0, 4
|
||||||
|
]);
|
||||||
|
|
||||||
|
assert_eq!(convert_csc_dense(&csc), dense);
|
||||||
|
assert_eq!(convert_dense_csc(&dense), csc);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn coo_strategy() -> impl Strategy<Value = CooMatrix<i32>> {
|
||||||
|
coo_with_duplicates(-5..=5, 0..=6usize, 0..=6usize, 40, 2)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn coo_no_duplicates_strategy() -> impl Strategy<Value = CooMatrix<i32>> {
|
||||||
|
coo_no_duplicates(-5..=5, 0..=6usize, 0..=6usize, 40)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn csr_strategy() -> impl Strategy<Value = CsrMatrix<i32>> {
|
||||||
|
csr(-5..=5, 0..=6usize, 0..=6usize, 40)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Avoid generating explicit zero values so that it is possible to reason about sparsity patterns
|
||||||
|
fn non_zero_csr_strategy() -> impl Strategy<Value = CsrMatrix<i32>> {
|
||||||
|
csr(1..=5, 0..=6usize, 0..=6usize, 40)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Avoid generating explicit zero values so that it is possible to reason about sparsity patterns
|
||||||
|
fn non_zero_csc_strategy() -> impl Strategy<Value = CscMatrix<i32>> {
|
||||||
|
csc(1..=5, 0..=6usize, 0..=6usize, 40)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn dense_strategy() -> impl Strategy<Value = DMatrix<i32>> {
|
||||||
|
matrix(-5..=5, 0..=6, 0..=6)
|
||||||
|
}
|
||||||
|
|
||||||
|
proptest! {
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn convert_dense_coo_roundtrip(dense in matrix(-5 ..= 5, 0 ..=6, 0..=6)) {
|
||||||
|
let coo = convert_dense_coo(&dense);
|
||||||
|
let dense2 = convert_coo_dense(&coo);
|
||||||
|
prop_assert_eq!(&dense, &dense2);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn convert_coo_dense_coo_roundtrip(coo in coo_strategy()) {
|
||||||
|
// We cannot compare the result of the roundtrip coo -> dense -> coo directly for
|
||||||
|
// two reasons:
|
||||||
|
// 1. the COO matrices will generally have different ordering of elements
|
||||||
|
// 2. explicitly stored zero entries in the original matrix will be discarded
|
||||||
|
// when converting back to COO
|
||||||
|
// Therefore we instead compare the results of converting the COO matrix
|
||||||
|
// at the end of the roundtrip with its dense representation
|
||||||
|
let dense = convert_coo_dense(&coo);
|
||||||
|
let coo2 = convert_dense_coo(&dense);
|
||||||
|
let dense2 = convert_coo_dense(&coo2);
|
||||||
|
prop_assert_eq!(dense, dense2);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn coo_from_dense_roundtrip(dense in dense_strategy()) {
|
||||||
|
prop_assert_eq!(&dense, &DMatrix::from(&CooMatrix::from(&dense)));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn convert_coo_csr_agrees_with_csr_dense(coo in coo_strategy()) {
|
||||||
|
let coo_dense = convert_coo_dense(&coo);
|
||||||
|
let csr = convert_coo_csr(&coo);
|
||||||
|
let csr_dense = convert_csr_dense(&csr);
|
||||||
|
prop_assert_eq!(csr_dense, coo_dense);
|
||||||
|
|
||||||
|
// It might be that COO matrices have a higher nnz due to duplicates,
|
||||||
|
// so we can only check that the CSR matrix has no more than the original COO matrix
|
||||||
|
prop_assert!(csr.nnz() <= coo.nnz());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn convert_coo_csr_nnz(coo in coo_no_duplicates_strategy()) {
|
||||||
|
// Check that the NNZ are equal when converting from a CooMatrix without
|
||||||
|
// duplicates to a CSR matrix
|
||||||
|
let csr = convert_coo_csr(&coo);
|
||||||
|
prop_assert_eq!(csr.nnz(), coo.nnz());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn convert_csr_coo_roundtrip(csr in csr_strategy()) {
|
||||||
|
let coo = convert_csr_coo(&csr);
|
||||||
|
let csr2 = convert_coo_csr(&coo);
|
||||||
|
prop_assert_eq!(csr2, csr);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn coo_from_csr_roundtrip(csr in csr_strategy()) {
|
||||||
|
prop_assert_eq!(&csr, &CsrMatrix::from(&CooMatrix::from(&csr)));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn csr_from_dense_roundtrip(dense in dense_strategy()) {
|
||||||
|
prop_assert_eq!(&dense, &DMatrix::from(&CsrMatrix::from(&dense)));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn convert_csr_dense_roundtrip(csr in non_zero_csr_strategy()) {
|
||||||
|
// Since we only generate CSR matrices with non-zero values, we know that the
|
||||||
|
// number of explicitly stored entries when converting CSR->Dense->CSR should be
|
||||||
|
// unchanged, so that we can verify that the result is the same as the input
|
||||||
|
let dense = convert_csr_dense(&csr);
|
||||||
|
let csr2 = convert_dense_csr(&dense);
|
||||||
|
prop_assert_eq!(csr2, csr);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn convert_csc_coo_roundtrip(csc in csc_strategy()) {
|
||||||
|
let coo = convert_csc_coo(&csc);
|
||||||
|
let csc2 = convert_coo_csc(&coo);
|
||||||
|
prop_assert_eq!(csc2, csc);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn coo_from_csc_roundtrip(csc in csc_strategy()) {
|
||||||
|
prop_assert_eq!(&csc, &CscMatrix::from(&CooMatrix::from(&csc)));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn convert_csc_dense_roundtrip(csc in non_zero_csc_strategy()) {
|
||||||
|
// Since we only generate CSC matrices with non-zero values, we know that the
|
||||||
|
// number of explicitly stored entries when converting CSC->Dense->CSC should be
|
||||||
|
// unchanged, so that we can verify that the result is the same as the input
|
||||||
|
let dense = convert_csc_dense(&csc);
|
||||||
|
let csc2 = convert_dense_csc(&dense);
|
||||||
|
prop_assert_eq!(csc2, csc);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn csc_from_dense_roundtrip(dense in dense_strategy()) {
|
||||||
|
prop_assert_eq!(&dense, &DMatrix::from(&CscMatrix::from(&dense)));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn convert_coo_csc_agrees_with_csc_dense(coo in coo_strategy()) {
|
||||||
|
let coo_dense = convert_coo_dense(&coo);
|
||||||
|
let csc = convert_coo_csc(&coo);
|
||||||
|
let csc_dense = convert_csc_dense(&csc);
|
||||||
|
prop_assert_eq!(csc_dense, coo_dense);
|
||||||
|
|
||||||
|
// It might be that COO matrices have a higher nnz due to duplicates,
|
||||||
|
// so we can only check that the CSR matrix has no more than the original COO matrix
|
||||||
|
prop_assert!(csc.nnz() <= coo.nnz());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn convert_coo_csc_nnz(coo in coo_no_duplicates_strategy()) {
|
||||||
|
// Check that the NNZ are equal when converting from a CooMatrix without
|
||||||
|
// duplicates to a CSR matrix
|
||||||
|
let csc = convert_coo_csc(&coo);
|
||||||
|
prop_assert_eq!(csc.nnz(), coo.nnz());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn convert_csc_csr_roundtrip(csc in csc_strategy()) {
|
||||||
|
let csr = convert_csc_csr(&csc);
|
||||||
|
let csc2 = convert_csr_csc(&csr);
|
||||||
|
prop_assert_eq!(csc2, csc);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn convert_csr_csc_roundtrip(csr in csr_strategy()) {
|
||||||
|
let csc = convert_csr_csc(&csr);
|
||||||
|
let csr2 = convert_csc_csr(&csc);
|
||||||
|
prop_assert_eq!(csr2, csr);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn csc_from_csr_roundtrip(csr in csr_strategy()) {
|
||||||
|
prop_assert_eq!(&csr, &CsrMatrix::from(&CscMatrix::from(&csr)));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn csr_from_csc_roundtrip(csc in csc_strategy()) {
|
||||||
|
prop_assert_eq!(&csc, &CscMatrix::from(&CsrMatrix::from(&csc)));
|
||||||
|
}
|
||||||
|
}
|
254
nalgebra-sparse/tests/unit_tests/coo.rs
Normal file
254
nalgebra-sparse/tests/unit_tests/coo.rs
Normal file
@ -0,0 +1,254 @@
|
|||||||
|
use crate::assert_panics;
|
||||||
|
use nalgebra::DMatrix;
|
||||||
|
use nalgebra_sparse::coo::CooMatrix;
|
||||||
|
use nalgebra_sparse::SparseFormatErrorKind;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn coo_construction_for_valid_data() {
|
||||||
|
// Test that construction with try_from_triplets succeeds, that the state of the
|
||||||
|
// matrix afterwards is as expected, and that the dense representation matches expectations.
|
||||||
|
|
||||||
|
{
|
||||||
|
// Zero matrix
|
||||||
|
let coo =
|
||||||
|
CooMatrix::<i32>::try_from_triplets(3, 2, Vec::new(), Vec::new(), Vec::new()).unwrap();
|
||||||
|
assert_eq!(coo.nrows(), 3);
|
||||||
|
assert_eq!(coo.ncols(), 2);
|
||||||
|
assert!(coo.triplet_iter().next().is_none());
|
||||||
|
assert!(coo.row_indices().is_empty());
|
||||||
|
assert!(coo.col_indices().is_empty());
|
||||||
|
assert!(coo.values().is_empty());
|
||||||
|
|
||||||
|
assert_eq!(DMatrix::from(&coo), DMatrix::repeat(3, 2, 0));
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
// Arbitrary matrix, no duplicates
|
||||||
|
let i = vec![0, 1, 0, 0, 2];
|
||||||
|
let j = vec![0, 2, 1, 3, 3];
|
||||||
|
let v = vec![2, 3, 7, 3, 1];
|
||||||
|
let coo =
|
||||||
|
CooMatrix::<i32>::try_from_triplets(3, 5, i.clone(), j.clone(), v.clone()).unwrap();
|
||||||
|
assert_eq!(coo.nrows(), 3);
|
||||||
|
assert_eq!(coo.ncols(), 5);
|
||||||
|
|
||||||
|
assert_eq!(i.as_slice(), coo.row_indices());
|
||||||
|
assert_eq!(j.as_slice(), coo.col_indices());
|
||||||
|
assert_eq!(v.as_slice(), coo.values());
|
||||||
|
|
||||||
|
let expected_triplets: Vec<_> = i
|
||||||
|
.iter()
|
||||||
|
.zip(&j)
|
||||||
|
.zip(&v)
|
||||||
|
.map(|((i, j), v)| (*i, *j, *v))
|
||||||
|
.collect();
|
||||||
|
let actual_triplets: Vec<_> = coo.triplet_iter().map(|(i, j, v)| (i, j, *v)).collect();
|
||||||
|
assert_eq!(actual_triplets, expected_triplets);
|
||||||
|
|
||||||
|
#[rustfmt::skip]
|
||||||
|
let expected_dense = DMatrix::from_row_slice(3, 5, &[
|
||||||
|
2, 7, 0, 3, 0,
|
||||||
|
0, 0, 3, 0, 0,
|
||||||
|
0, 0, 0, 1, 0
|
||||||
|
]);
|
||||||
|
assert_eq!(DMatrix::from(&coo), expected_dense);
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
// Arbitrary matrix, with duplicates
|
||||||
|
let i = vec![0, 1, 0, 0, 0, 0, 2, 1];
|
||||||
|
let j = vec![0, 2, 0, 1, 0, 3, 3, 2];
|
||||||
|
let v = vec![2, 3, 4, 7, 1, 3, 1, 5];
|
||||||
|
let coo =
|
||||||
|
CooMatrix::<i32>::try_from_triplets(3, 5, i.clone(), j.clone(), v.clone()).unwrap();
|
||||||
|
assert_eq!(coo.nrows(), 3);
|
||||||
|
assert_eq!(coo.ncols(), 5);
|
||||||
|
|
||||||
|
assert_eq!(i.as_slice(), coo.row_indices());
|
||||||
|
assert_eq!(j.as_slice(), coo.col_indices());
|
||||||
|
assert_eq!(v.as_slice(), coo.values());
|
||||||
|
|
||||||
|
let expected_triplets: Vec<_> = i
|
||||||
|
.iter()
|
||||||
|
.zip(&j)
|
||||||
|
.zip(&v)
|
||||||
|
.map(|((i, j), v)| (*i, *j, *v))
|
||||||
|
.collect();
|
||||||
|
let actual_triplets: Vec<_> = coo.triplet_iter().map(|(i, j, v)| (i, j, *v)).collect();
|
||||||
|
assert_eq!(actual_triplets, expected_triplets);
|
||||||
|
|
||||||
|
#[rustfmt::skip]
|
||||||
|
let expected_dense = DMatrix::from_row_slice(3, 5, &[
|
||||||
|
7, 7, 0, 3, 0,
|
||||||
|
0, 0, 8, 0, 0,
|
||||||
|
0, 0, 0, 1, 0
|
||||||
|
]);
|
||||||
|
assert_eq!(DMatrix::from(&coo), expected_dense);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn coo_try_from_triplets_reports_out_of_bounds_indices() {
|
||||||
|
{
|
||||||
|
// 0x0 matrix
|
||||||
|
let result = CooMatrix::<i32>::try_from_triplets(0, 0, vec![0], vec![0], vec![2]);
|
||||||
|
assert!(matches!(
|
||||||
|
result.unwrap_err().kind(),
|
||||||
|
SparseFormatErrorKind::IndexOutOfBounds
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
// 1x1 matrix, row out of bounds
|
||||||
|
let result = CooMatrix::<i32>::try_from_triplets(1, 1, vec![1], vec![0], vec![2]);
|
||||||
|
assert!(matches!(
|
||||||
|
result.unwrap_err().kind(),
|
||||||
|
SparseFormatErrorKind::IndexOutOfBounds
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
// 1x1 matrix, col out of bounds
|
||||||
|
let result = CooMatrix::<i32>::try_from_triplets(1, 1, vec![0], vec![1], vec![2]);
|
||||||
|
assert!(matches!(
|
||||||
|
result.unwrap_err().kind(),
|
||||||
|
SparseFormatErrorKind::IndexOutOfBounds
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
// 1x1 matrix, row and col out of bounds
|
||||||
|
let result = CooMatrix::<i32>::try_from_triplets(1, 1, vec![1], vec![1], vec![2]);
|
||||||
|
assert!(matches!(
|
||||||
|
result.unwrap_err().kind(),
|
||||||
|
SparseFormatErrorKind::IndexOutOfBounds
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
// Arbitrary matrix, row out of bounds
|
||||||
|
let i = vec![0, 1, 0, 3, 2];
|
||||||
|
let j = vec![0, 2, 1, 3, 3];
|
||||||
|
let v = vec![2, 3, 7, 3, 1];
|
||||||
|
let result = CooMatrix::<i32>::try_from_triplets(3, 5, i, j, v);
|
||||||
|
assert!(matches!(
|
||||||
|
result.unwrap_err().kind(),
|
||||||
|
SparseFormatErrorKind::IndexOutOfBounds
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
// Arbitrary matrix, col out of bounds
|
||||||
|
let i = vec![0, 1, 0, 0, 2];
|
||||||
|
let j = vec![0, 2, 1, 5, 3];
|
||||||
|
let v = vec![2, 3, 7, 3, 1];
|
||||||
|
let result = CooMatrix::<i32>::try_from_triplets(3, 5, i, j, v);
|
||||||
|
assert!(matches!(
|
||||||
|
result.unwrap_err().kind(),
|
||||||
|
SparseFormatErrorKind::IndexOutOfBounds
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn coo_try_from_triplets_panics_on_mismatched_vectors() {
|
||||||
|
// Check that try_from_triplets panics when the triplet vectors have different lengths
|
||||||
|
macro_rules! assert_errs {
|
||||||
|
($result:expr) => {
|
||||||
|
assert!(matches!(
|
||||||
|
$result.unwrap_err().kind(),
|
||||||
|
SparseFormatErrorKind::InvalidStructure
|
||||||
|
))
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
assert_errs!(CooMatrix::<i32>::try_from_triplets(
|
||||||
|
3,
|
||||||
|
5,
|
||||||
|
vec![1, 2],
|
||||||
|
vec![0],
|
||||||
|
vec![0]
|
||||||
|
));
|
||||||
|
assert_errs!(CooMatrix::<i32>::try_from_triplets(
|
||||||
|
3,
|
||||||
|
5,
|
||||||
|
vec![1],
|
||||||
|
vec![0, 0],
|
||||||
|
vec![0]
|
||||||
|
));
|
||||||
|
assert_errs!(CooMatrix::<i32>::try_from_triplets(
|
||||||
|
3,
|
||||||
|
5,
|
||||||
|
vec![1],
|
||||||
|
vec![0],
|
||||||
|
vec![0, 1]
|
||||||
|
));
|
||||||
|
assert_errs!(CooMatrix::<i32>::try_from_triplets(
|
||||||
|
3,
|
||||||
|
5,
|
||||||
|
vec![1, 2],
|
||||||
|
vec![0, 1],
|
||||||
|
vec![0]
|
||||||
|
));
|
||||||
|
assert_errs!(CooMatrix::<i32>::try_from_triplets(
|
||||||
|
3,
|
||||||
|
5,
|
||||||
|
vec![1],
|
||||||
|
vec![0, 1],
|
||||||
|
vec![0, 1]
|
||||||
|
));
|
||||||
|
assert_errs!(CooMatrix::<i32>::try_from_triplets(
|
||||||
|
3,
|
||||||
|
5,
|
||||||
|
vec![1, 1],
|
||||||
|
vec![0],
|
||||||
|
vec![0, 1]
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn coo_push_valid_entries() {
|
||||||
|
let mut coo = CooMatrix::new(3, 3);
|
||||||
|
|
||||||
|
coo.push(0, 0, 1);
|
||||||
|
assert_eq!(coo.triplet_iter().collect::<Vec<_>>(), vec![(0, 0, &1)]);
|
||||||
|
|
||||||
|
coo.push(0, 0, 2);
|
||||||
|
assert_eq!(
|
||||||
|
coo.triplet_iter().collect::<Vec<_>>(),
|
||||||
|
vec![(0, 0, &1), (0, 0, &2)]
|
||||||
|
);
|
||||||
|
|
||||||
|
coo.push(2, 2, 3);
|
||||||
|
assert_eq!(
|
||||||
|
coo.triplet_iter().collect::<Vec<_>>(),
|
||||||
|
vec![(0, 0, &1), (0, 0, &2), (2, 2, &3)]
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn coo_push_out_of_bounds_entries() {
|
||||||
|
{
|
||||||
|
// 0x0 matrix
|
||||||
|
let coo = CooMatrix::new(0, 0);
|
||||||
|
assert_panics!(coo.clone().push(0, 0, 1));
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
// 0x1 matrix
|
||||||
|
assert_panics!(CooMatrix::new(0, 1).push(0, 0, 1));
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
// 1x0 matrix
|
||||||
|
assert_panics!(CooMatrix::new(1, 0).push(0, 0, 1));
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
// Arbitrary matrix dimensions
|
||||||
|
let coo = CooMatrix::new(3, 2);
|
||||||
|
assert_panics!(coo.clone().push(3, 0, 1));
|
||||||
|
assert_panics!(coo.clone().push(2, 2, 1));
|
||||||
|
assert_panics!(coo.clone().push(3, 2, 1));
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,7 @@
|
|||||||
|
# Seeds for failure cases proptest has generated in the past. It is
|
||||||
|
# automatically read and these particular cases re-run before any
|
||||||
|
# novel cases are generated.
|
||||||
|
#
|
||||||
|
# It is recommended to check this file in to source control so that
|
||||||
|
# everyone who runs the test benefits from these saved cases.
|
||||||
|
cc a71b4654827840ed539b82cd7083615b0fb3f75933de6a7d91d8148a2bf34960 # shrinks to (csc, triplet_subset) = (CscMatrix { cs: CsMatrix { sparsity_pattern: SparsityPattern { major_offsets: [0, 1, 1, 1, 1, 1, 1], minor_indices: [0], minor_dim: 4 }, values: [0] } }, {})
|
605
nalgebra-sparse/tests/unit_tests/csc.rs
Normal file
605
nalgebra-sparse/tests/unit_tests/csc.rs
Normal file
@ -0,0 +1,605 @@
|
|||||||
|
use nalgebra::DMatrix;
|
||||||
|
use nalgebra_sparse::csc::CscMatrix;
|
||||||
|
use nalgebra_sparse::{SparseEntry, SparseEntryMut, SparseFormatErrorKind};
|
||||||
|
|
||||||
|
use proptest::prelude::*;
|
||||||
|
use proptest::sample::subsequence;
|
||||||
|
|
||||||
|
use crate::assert_panics;
|
||||||
|
use crate::common::csc_strategy;
|
||||||
|
|
||||||
|
use std::collections::HashSet;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn csc_matrix_valid_data() {
|
||||||
|
// Construct matrix from valid data and check that selected methods return results
|
||||||
|
// that agree with expectations.
|
||||||
|
|
||||||
|
{
|
||||||
|
// A CSC matrix with zero explicitly stored entries
|
||||||
|
let offsets = vec![0, 0, 0, 0];
|
||||||
|
let indices = vec![];
|
||||||
|
let values = Vec::<i32>::new();
|
||||||
|
let mut matrix = CscMatrix::try_from_csc_data(2, 3, offsets, indices, values).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(matrix, CscMatrix::zeros(2, 3));
|
||||||
|
|
||||||
|
assert_eq!(matrix.nrows(), 2);
|
||||||
|
assert_eq!(matrix.ncols(), 3);
|
||||||
|
assert_eq!(matrix.nnz(), 0);
|
||||||
|
assert_eq!(matrix.col_offsets(), &[0, 0, 0, 0]);
|
||||||
|
assert_eq!(matrix.row_indices(), &[]);
|
||||||
|
assert_eq!(matrix.values(), &[]);
|
||||||
|
|
||||||
|
assert!(matrix.triplet_iter().next().is_none());
|
||||||
|
assert!(matrix.triplet_iter_mut().next().is_none());
|
||||||
|
|
||||||
|
assert_eq!(matrix.col(0).nrows(), 2);
|
||||||
|
assert_eq!(matrix.col(0).nnz(), 0);
|
||||||
|
assert_eq!(matrix.col(0).row_indices(), &[]);
|
||||||
|
assert_eq!(matrix.col(0).values(), &[]);
|
||||||
|
assert_eq!(matrix.col_mut(0).nrows(), 2);
|
||||||
|
assert_eq!(matrix.col_mut(0).nnz(), 0);
|
||||||
|
assert_eq!(matrix.col_mut(0).row_indices(), &[]);
|
||||||
|
assert_eq!(matrix.col_mut(0).values(), &[]);
|
||||||
|
assert_eq!(matrix.col_mut(0).values_mut(), &[]);
|
||||||
|
assert_eq!(
|
||||||
|
matrix.col_mut(0).rows_and_values_mut(),
|
||||||
|
([].as_ref(), [].as_mut())
|
||||||
|
);
|
||||||
|
|
||||||
|
assert_eq!(matrix.col(1).nrows(), 2);
|
||||||
|
assert_eq!(matrix.col(1).nnz(), 0);
|
||||||
|
assert_eq!(matrix.col(1).row_indices(), &[]);
|
||||||
|
assert_eq!(matrix.col(1).values(), &[]);
|
||||||
|
assert_eq!(matrix.col_mut(1).nrows(), 2);
|
||||||
|
assert_eq!(matrix.col_mut(1).nnz(), 0);
|
||||||
|
assert_eq!(matrix.col_mut(1).row_indices(), &[]);
|
||||||
|
assert_eq!(matrix.col_mut(1).values(), &[]);
|
||||||
|
assert_eq!(matrix.col_mut(1).values_mut(), &[]);
|
||||||
|
assert_eq!(
|
||||||
|
matrix.col_mut(1).rows_and_values_mut(),
|
||||||
|
([].as_ref(), [].as_mut())
|
||||||
|
);
|
||||||
|
|
||||||
|
assert_eq!(matrix.col(2).nrows(), 2);
|
||||||
|
assert_eq!(matrix.col(2).nnz(), 0);
|
||||||
|
assert_eq!(matrix.col(2).row_indices(), &[]);
|
||||||
|
assert_eq!(matrix.col(2).values(), &[]);
|
||||||
|
assert_eq!(matrix.col_mut(2).nrows(), 2);
|
||||||
|
assert_eq!(matrix.col_mut(2).nnz(), 0);
|
||||||
|
assert_eq!(matrix.col_mut(2).row_indices(), &[]);
|
||||||
|
assert_eq!(matrix.col_mut(2).values(), &[]);
|
||||||
|
assert_eq!(matrix.col_mut(2).values_mut(), &[]);
|
||||||
|
assert_eq!(
|
||||||
|
matrix.col_mut(2).rows_and_values_mut(),
|
||||||
|
([].as_ref(), [].as_mut())
|
||||||
|
);
|
||||||
|
|
||||||
|
assert!(matrix.get_col(3).is_none());
|
||||||
|
assert!(matrix.get_col_mut(3).is_none());
|
||||||
|
|
||||||
|
let (offsets, indices, values) = matrix.disassemble();
|
||||||
|
|
||||||
|
assert_eq!(offsets, vec![0, 0, 0, 0]);
|
||||||
|
assert_eq!(indices, vec![]);
|
||||||
|
assert_eq!(values, vec![]);
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
// An arbitrary CSC matrix
|
||||||
|
let offsets = vec![0, 2, 2, 5];
|
||||||
|
let indices = vec![0, 5, 1, 2, 3];
|
||||||
|
let values = vec![0, 1, 2, 3, 4];
|
||||||
|
let mut matrix =
|
||||||
|
CscMatrix::try_from_csc_data(6, 3, offsets.clone(), indices.clone(), values.clone())
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert_eq!(matrix.nrows(), 6);
|
||||||
|
assert_eq!(matrix.ncols(), 3);
|
||||||
|
assert_eq!(matrix.nnz(), 5);
|
||||||
|
assert_eq!(matrix.col_offsets(), &[0, 2, 2, 5]);
|
||||||
|
assert_eq!(matrix.row_indices(), &[0, 5, 1, 2, 3]);
|
||||||
|
assert_eq!(matrix.values(), &[0, 1, 2, 3, 4]);
|
||||||
|
|
||||||
|
let expected_triplets = vec![(0, 0, 0), (5, 0, 1), (1, 2, 2), (2, 2, 3), (3, 2, 4)];
|
||||||
|
assert_eq!(
|
||||||
|
matrix
|
||||||
|
.triplet_iter()
|
||||||
|
.map(|(i, j, v)| (i, j, *v))
|
||||||
|
.collect::<Vec<_>>(),
|
||||||
|
expected_triplets
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
matrix
|
||||||
|
.triplet_iter_mut()
|
||||||
|
.map(|(i, j, v)| (i, j, *v))
|
||||||
|
.collect::<Vec<_>>(),
|
||||||
|
expected_triplets
|
||||||
|
);
|
||||||
|
|
||||||
|
assert_eq!(matrix.col(0).nrows(), 6);
|
||||||
|
assert_eq!(matrix.col(0).nnz(), 2);
|
||||||
|
assert_eq!(matrix.col(0).row_indices(), &[0, 5]);
|
||||||
|
assert_eq!(matrix.col(0).values(), &[0, 1]);
|
||||||
|
assert_eq!(matrix.col_mut(0).nrows(), 6);
|
||||||
|
assert_eq!(matrix.col_mut(0).nnz(), 2);
|
||||||
|
assert_eq!(matrix.col_mut(0).row_indices(), &[0, 5]);
|
||||||
|
assert_eq!(matrix.col_mut(0).values(), &[0, 1]);
|
||||||
|
assert_eq!(matrix.col_mut(0).values_mut(), &[0, 1]);
|
||||||
|
assert_eq!(
|
||||||
|
matrix.col_mut(0).rows_and_values_mut(),
|
||||||
|
([0, 5].as_ref(), [0, 1].as_mut())
|
||||||
|
);
|
||||||
|
|
||||||
|
assert_eq!(matrix.col(1).nrows(), 6);
|
||||||
|
assert_eq!(matrix.col(1).nnz(), 0);
|
||||||
|
assert_eq!(matrix.col(1).row_indices(), &[]);
|
||||||
|
assert_eq!(matrix.col(1).values(), &[]);
|
||||||
|
assert_eq!(matrix.col_mut(1).nrows(), 6);
|
||||||
|
assert_eq!(matrix.col_mut(1).nnz(), 0);
|
||||||
|
assert_eq!(matrix.col_mut(1).row_indices(), &[]);
|
||||||
|
assert_eq!(matrix.col_mut(1).values(), &[]);
|
||||||
|
assert_eq!(matrix.col_mut(1).values_mut(), &[]);
|
||||||
|
assert_eq!(
|
||||||
|
matrix.col_mut(1).rows_and_values_mut(),
|
||||||
|
([].as_ref(), [].as_mut())
|
||||||
|
);
|
||||||
|
|
||||||
|
assert_eq!(matrix.col(2).nrows(), 6);
|
||||||
|
assert_eq!(matrix.col(2).nnz(), 3);
|
||||||
|
assert_eq!(matrix.col(2).row_indices(), &[1, 2, 3]);
|
||||||
|
assert_eq!(matrix.col(2).values(), &[2, 3, 4]);
|
||||||
|
assert_eq!(matrix.col_mut(2).nrows(), 6);
|
||||||
|
assert_eq!(matrix.col_mut(2).nnz(), 3);
|
||||||
|
assert_eq!(matrix.col_mut(2).row_indices(), &[1, 2, 3]);
|
||||||
|
assert_eq!(matrix.col_mut(2).values(), &[2, 3, 4]);
|
||||||
|
assert_eq!(matrix.col_mut(2).values_mut(), &[2, 3, 4]);
|
||||||
|
assert_eq!(
|
||||||
|
matrix.col_mut(2).rows_and_values_mut(),
|
||||||
|
([1, 2, 3].as_ref(), [2, 3, 4].as_mut())
|
||||||
|
);
|
||||||
|
|
||||||
|
assert!(matrix.get_col(3).is_none());
|
||||||
|
assert!(matrix.get_col_mut(3).is_none());
|
||||||
|
|
||||||
|
let (offsets2, indices2, values2) = matrix.disassemble();
|
||||||
|
|
||||||
|
assert_eq!(offsets2, offsets);
|
||||||
|
assert_eq!(indices2, indices);
|
||||||
|
assert_eq!(values2, values);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn csc_matrix_try_from_invalid_csc_data() {
|
||||||
|
{
|
||||||
|
// Empty offset array (invalid length)
|
||||||
|
let matrix = CscMatrix::try_from_csc_data(0, 0, Vec::new(), Vec::new(), Vec::<u32>::new());
|
||||||
|
assert_eq!(
|
||||||
|
matrix.unwrap_err().kind(),
|
||||||
|
&SparseFormatErrorKind::InvalidStructure
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
// Offset array invalid length for arbitrary data
|
||||||
|
let offsets = vec![0, 3, 5];
|
||||||
|
let indices = vec![0, 1, 2, 3, 5];
|
||||||
|
let values = vec![0, 1, 2, 3, 4];
|
||||||
|
|
||||||
|
let matrix = CscMatrix::try_from_csc_data(6, 3, offsets, indices, values);
|
||||||
|
assert_eq!(
|
||||||
|
matrix.unwrap_err().kind(),
|
||||||
|
&SparseFormatErrorKind::InvalidStructure
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
// Invalid first entry in offsets array
|
||||||
|
let offsets = vec![1, 2, 2, 5];
|
||||||
|
let indices = vec![0, 5, 1, 2, 3];
|
||||||
|
let values = vec![0, 1, 2, 3, 4];
|
||||||
|
let matrix = CscMatrix::try_from_csc_data(6, 3, offsets, indices, values);
|
||||||
|
assert_eq!(
|
||||||
|
matrix.unwrap_err().kind(),
|
||||||
|
&SparseFormatErrorKind::InvalidStructure
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
// Invalid last entry in offsets array
|
||||||
|
let offsets = vec![0, 2, 2, 4];
|
||||||
|
let indices = vec![0, 5, 1, 2, 3];
|
||||||
|
let values = vec![0, 1, 2, 3, 4];
|
||||||
|
let matrix = CscMatrix::try_from_csc_data(6, 3, offsets, indices, values);
|
||||||
|
assert_eq!(
|
||||||
|
matrix.unwrap_err().kind(),
|
||||||
|
&SparseFormatErrorKind::InvalidStructure
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
// Invalid length of offsets array
|
||||||
|
let offsets = vec![0, 2, 2];
|
||||||
|
let indices = vec![0, 5, 1, 2, 3];
|
||||||
|
let values = vec![0, 1, 2, 3, 4];
|
||||||
|
let matrix = CscMatrix::try_from_csc_data(6, 3, offsets, indices, values);
|
||||||
|
assert_eq!(
|
||||||
|
matrix.unwrap_err().kind(),
|
||||||
|
&SparseFormatErrorKind::InvalidStructure
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
// Nonmonotonic offsets
|
||||||
|
let offsets = vec![0, 3, 2, 5];
|
||||||
|
let indices = vec![0, 1, 2, 3, 4];
|
||||||
|
let values = vec![0, 1, 2, 3, 4];
|
||||||
|
let matrix = CscMatrix::try_from_csc_data(6, 3, offsets, indices, values);
|
||||||
|
assert_eq!(
|
||||||
|
matrix.unwrap_err().kind(),
|
||||||
|
&SparseFormatErrorKind::InvalidStructure
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
// Nonmonotonic minor indices
|
||||||
|
let offsets = vec![0, 2, 2, 5];
|
||||||
|
let indices = vec![0, 2, 3, 1, 4];
|
||||||
|
let values = vec![0, 1, 2, 3, 4];
|
||||||
|
let matrix = CscMatrix::try_from_csc_data(6, 3, offsets, indices, values);
|
||||||
|
assert_eq!(
|
||||||
|
matrix.unwrap_err().kind(),
|
||||||
|
&SparseFormatErrorKind::InvalidStructure
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
// Minor index out of bounds
|
||||||
|
let offsets = vec![0, 2, 2, 5];
|
||||||
|
let indices = vec![0, 6, 1, 2, 3];
|
||||||
|
let values = vec![0, 1, 2, 3, 4];
|
||||||
|
let matrix = CscMatrix::try_from_csc_data(6, 3, offsets, indices, values);
|
||||||
|
assert_eq!(
|
||||||
|
matrix.unwrap_err().kind(),
|
||||||
|
&SparseFormatErrorKind::IndexOutOfBounds
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
// Duplicate entry
|
||||||
|
let offsets = vec![0, 2, 2, 5];
|
||||||
|
let indices = vec![0, 5, 2, 2, 3];
|
||||||
|
let values = vec![0, 1, 2, 3, 4];
|
||||||
|
let matrix = CscMatrix::try_from_csc_data(6, 3, offsets, indices, values);
|
||||||
|
assert_eq!(
|
||||||
|
matrix.unwrap_err().kind(),
|
||||||
|
&SparseFormatErrorKind::DuplicateEntry
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn csc_disassemble_avoids_clone_when_owned() {
|
||||||
|
// Test that disassemble avoids cloning the sparsity pattern when it holds the sole reference
|
||||||
|
// to the pattern. We do so by checking that the pointer to the data is unchanged.
|
||||||
|
|
||||||
|
let offsets = vec![0, 2, 2, 5];
|
||||||
|
let indices = vec![0, 5, 1, 2, 3];
|
||||||
|
let values = vec![0, 1, 2, 3, 4];
|
||||||
|
let offsets_ptr = offsets.as_ptr();
|
||||||
|
let indices_ptr = indices.as_ptr();
|
||||||
|
let values_ptr = values.as_ptr();
|
||||||
|
let matrix = CscMatrix::try_from_csc_data(6, 3, offsets, indices, values).unwrap();
|
||||||
|
|
||||||
|
let (offsets, indices, values) = matrix.disassemble();
|
||||||
|
assert_eq!(offsets.as_ptr(), offsets_ptr);
|
||||||
|
assert_eq!(indices.as_ptr(), indices_ptr);
|
||||||
|
assert_eq!(values.as_ptr(), values_ptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Rustfmt makes this test much harder to read by expanding some of the one-liners to 4-liners,
|
||||||
|
// so for now we skip rustfmt...
|
||||||
|
#[rustfmt::skip]
|
||||||
|
#[test]
|
||||||
|
fn csc_matrix_get_index_entry() {
|
||||||
|
// Test .get_entry(_mut) and .index_entry(_mut) methods
|
||||||
|
|
||||||
|
#[rustfmt::skip]
|
||||||
|
let dense = DMatrix::from_row_slice(2, 3, &[
|
||||||
|
1, 0, 3,
|
||||||
|
0, 5, 6
|
||||||
|
]);
|
||||||
|
let csc = CscMatrix::from(&dense);
|
||||||
|
|
||||||
|
assert_eq!(csc.get_entry(0, 0), Some(SparseEntry::NonZero(&1)));
|
||||||
|
assert_eq!(csc.index_entry(0, 0), SparseEntry::NonZero(&1));
|
||||||
|
assert_eq!(csc.get_entry(0, 1), Some(SparseEntry::Zero));
|
||||||
|
assert_eq!(csc.index_entry(0, 1), SparseEntry::Zero);
|
||||||
|
assert_eq!(csc.get_entry(0, 2), Some(SparseEntry::NonZero(&3)));
|
||||||
|
assert_eq!(csc.index_entry(0, 2), SparseEntry::NonZero(&3));
|
||||||
|
assert_eq!(csc.get_entry(1, 0), Some(SparseEntry::Zero));
|
||||||
|
assert_eq!(csc.index_entry(1, 0), SparseEntry::Zero);
|
||||||
|
assert_eq!(csc.get_entry(1, 1), Some(SparseEntry::NonZero(&5)));
|
||||||
|
assert_eq!(csc.index_entry(1, 1), SparseEntry::NonZero(&5));
|
||||||
|
assert_eq!(csc.get_entry(1, 2), Some(SparseEntry::NonZero(&6)));
|
||||||
|
assert_eq!(csc.index_entry(1, 2), SparseEntry::NonZero(&6));
|
||||||
|
|
||||||
|
// Check some out of bounds with .get_entry
|
||||||
|
assert_eq!(csc.get_entry(0, 3), None);
|
||||||
|
assert_eq!(csc.get_entry(0, 4), None);
|
||||||
|
assert_eq!(csc.get_entry(1, 3), None);
|
||||||
|
assert_eq!(csc.get_entry(1, 4), None);
|
||||||
|
assert_eq!(csc.get_entry(2, 0), None);
|
||||||
|
assert_eq!(csc.get_entry(2, 1), None);
|
||||||
|
assert_eq!(csc.get_entry(2, 2), None);
|
||||||
|
assert_eq!(csc.get_entry(2, 3), None);
|
||||||
|
assert_eq!(csc.get_entry(2, 4), None);
|
||||||
|
|
||||||
|
// Check that out of bounds with .index_entry panics
|
||||||
|
assert_panics!(csc.index_entry(0, 3));
|
||||||
|
assert_panics!(csc.index_entry(0, 4));
|
||||||
|
assert_panics!(csc.index_entry(1, 3));
|
||||||
|
assert_panics!(csc.index_entry(1, 4));
|
||||||
|
assert_panics!(csc.index_entry(2, 0));
|
||||||
|
assert_panics!(csc.index_entry(2, 1));
|
||||||
|
assert_panics!(csc.index_entry(2, 2));
|
||||||
|
assert_panics!(csc.index_entry(2, 3));
|
||||||
|
assert_panics!(csc.index_entry(2, 4));
|
||||||
|
|
||||||
|
{
|
||||||
|
// Check mutable versions of the above functions
|
||||||
|
let mut csc = csc;
|
||||||
|
|
||||||
|
assert_eq!(csc.get_entry_mut(0, 0), Some(SparseEntryMut::NonZero(&mut 1)));
|
||||||
|
assert_eq!(csc.index_entry_mut(0, 0), SparseEntryMut::NonZero(&mut 1));
|
||||||
|
assert_eq!(csc.get_entry_mut(0, 1), Some(SparseEntryMut::Zero));
|
||||||
|
assert_eq!(csc.index_entry_mut(0, 1), SparseEntryMut::Zero);
|
||||||
|
assert_eq!(csc.get_entry_mut(0, 2), Some(SparseEntryMut::NonZero(&mut 3)));
|
||||||
|
assert_eq!(csc.index_entry_mut(0, 2), SparseEntryMut::NonZero(&mut 3));
|
||||||
|
assert_eq!(csc.get_entry_mut(1, 0), Some(SparseEntryMut::Zero));
|
||||||
|
assert_eq!(csc.index_entry_mut(1, 0), SparseEntryMut::Zero);
|
||||||
|
assert_eq!(csc.get_entry_mut(1, 1), Some(SparseEntryMut::NonZero(&mut 5)));
|
||||||
|
assert_eq!(csc.index_entry_mut(1, 1), SparseEntryMut::NonZero(&mut 5));
|
||||||
|
assert_eq!(csc.get_entry_mut(1, 2), Some(SparseEntryMut::NonZero(&mut 6)));
|
||||||
|
assert_eq!(csc.index_entry_mut(1, 2), SparseEntryMut::NonZero(&mut 6));
|
||||||
|
|
||||||
|
// Check some out of bounds with .get_entry_mut
|
||||||
|
assert_eq!(csc.get_entry_mut(0, 3), None);
|
||||||
|
assert_eq!(csc.get_entry_mut(0, 4), None);
|
||||||
|
assert_eq!(csc.get_entry_mut(1, 3), None);
|
||||||
|
assert_eq!(csc.get_entry_mut(1, 4), None);
|
||||||
|
assert_eq!(csc.get_entry_mut(2, 0), None);
|
||||||
|
assert_eq!(csc.get_entry_mut(2, 1), None);
|
||||||
|
assert_eq!(csc.get_entry_mut(2, 2), None);
|
||||||
|
assert_eq!(csc.get_entry_mut(2, 3), None);
|
||||||
|
assert_eq!(csc.get_entry_mut(2, 4), None);
|
||||||
|
|
||||||
|
// Check that out of bounds with .index_entry_mut panics
|
||||||
|
// Note: the cloning is necessary because a mutable reference is not UnwindSafe
|
||||||
|
assert_panics!({ let mut csc = csc.clone(); csc.index_entry_mut(0, 3); });
|
||||||
|
assert_panics!({ let mut csc = csc.clone(); csc.index_entry_mut(0, 4); });
|
||||||
|
assert_panics!({ let mut csc = csc.clone(); csc.index_entry_mut(1, 3); });
|
||||||
|
assert_panics!({ let mut csc = csc.clone(); csc.index_entry_mut(1, 4); });
|
||||||
|
assert_panics!({ let mut csc = csc.clone(); csc.index_entry_mut(2, 0); });
|
||||||
|
assert_panics!({ let mut csc = csc.clone(); csc.index_entry_mut(2, 1); });
|
||||||
|
assert_panics!({ let mut csc = csc.clone(); csc.index_entry_mut(2, 2); });
|
||||||
|
assert_panics!({ let mut csc = csc.clone(); csc.index_entry_mut(2, 3); });
|
||||||
|
assert_panics!({ let mut csc = csc.clone(); csc.index_entry_mut(2, 4); });
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn csc_matrix_col_iter() {
|
||||||
|
// Note: this is the transpose of the matrix used for the similar csr_matrix_row_iter test
|
||||||
|
// (this way the actual tests are almost identical, due to the transposed relationship
|
||||||
|
// between CSR and CSC)
|
||||||
|
#[rustfmt::skip]
|
||||||
|
let dense = DMatrix::from_row_slice(4, 3, &[
|
||||||
|
0, 3, 0,
|
||||||
|
1, 0, 4,
|
||||||
|
2, 0, 0,
|
||||||
|
0, 0, 5,
|
||||||
|
]);
|
||||||
|
let csc = CscMatrix::from(&dense);
|
||||||
|
|
||||||
|
// Immutable iterator
|
||||||
|
{
|
||||||
|
let mut col_iter = csc.col_iter();
|
||||||
|
|
||||||
|
{
|
||||||
|
let col = col_iter.next().unwrap();
|
||||||
|
assert_eq!(col.nrows(), 4);
|
||||||
|
assert_eq!(col.nnz(), 2);
|
||||||
|
assert_eq!(col.row_indices(), &[1, 2]);
|
||||||
|
assert_eq!(col.values(), &[1, 2]);
|
||||||
|
assert_eq!(col.get_entry(0), Some(SparseEntry::Zero));
|
||||||
|
assert_eq!(col.get_entry(1), Some(SparseEntry::NonZero(&1)));
|
||||||
|
assert_eq!(col.get_entry(2), Some(SparseEntry::NonZero(&2)));
|
||||||
|
assert_eq!(col.get_entry(3), Some(SparseEntry::Zero));
|
||||||
|
assert_eq!(col.get_entry(4), None);
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
let col = col_iter.next().unwrap();
|
||||||
|
assert_eq!(col.nrows(), 4);
|
||||||
|
assert_eq!(col.nnz(), 1);
|
||||||
|
assert_eq!(col.row_indices(), &[0]);
|
||||||
|
assert_eq!(col.values(), &[3]);
|
||||||
|
assert_eq!(col.get_entry(0), Some(SparseEntry::NonZero(&3)));
|
||||||
|
assert_eq!(col.get_entry(1), Some(SparseEntry::Zero));
|
||||||
|
assert_eq!(col.get_entry(2), Some(SparseEntry::Zero));
|
||||||
|
assert_eq!(col.get_entry(3), Some(SparseEntry::Zero));
|
||||||
|
assert_eq!(col.get_entry(4), None);
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
let col = col_iter.next().unwrap();
|
||||||
|
assert_eq!(col.nrows(), 4);
|
||||||
|
assert_eq!(col.nnz(), 2);
|
||||||
|
assert_eq!(col.row_indices(), &[1, 3]);
|
||||||
|
assert_eq!(col.values(), &[4, 5]);
|
||||||
|
assert_eq!(col.get_entry(0), Some(SparseEntry::Zero));
|
||||||
|
assert_eq!(col.get_entry(1), Some(SparseEntry::NonZero(&4)));
|
||||||
|
assert_eq!(col.get_entry(2), Some(SparseEntry::Zero));
|
||||||
|
assert_eq!(col.get_entry(3), Some(SparseEntry::NonZero(&5)));
|
||||||
|
assert_eq!(col.get_entry(4), None);
|
||||||
|
}
|
||||||
|
|
||||||
|
assert!(col_iter.next().is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mutable iterator
|
||||||
|
{
|
||||||
|
let mut csc = csc;
|
||||||
|
let mut col_iter = csc.col_iter_mut();
|
||||||
|
|
||||||
|
{
|
||||||
|
let mut col = col_iter.next().unwrap();
|
||||||
|
assert_eq!(col.nrows(), 4);
|
||||||
|
assert_eq!(col.nnz(), 2);
|
||||||
|
assert_eq!(col.row_indices(), &[1, 2]);
|
||||||
|
assert_eq!(col.values(), &[1, 2]);
|
||||||
|
assert_eq!(col.get_entry(0), Some(SparseEntry::Zero));
|
||||||
|
assert_eq!(col.get_entry(1), Some(SparseEntry::NonZero(&1)));
|
||||||
|
assert_eq!(col.get_entry(2), Some(SparseEntry::NonZero(&2)));
|
||||||
|
assert_eq!(col.get_entry(3), Some(SparseEntry::Zero));
|
||||||
|
assert_eq!(col.get_entry(4), None);
|
||||||
|
|
||||||
|
assert_eq!(col.values_mut(), &mut [1, 2]);
|
||||||
|
assert_eq!(
|
||||||
|
col.rows_and_values_mut(),
|
||||||
|
([1, 2].as_ref(), [1, 2].as_mut())
|
||||||
|
);
|
||||||
|
assert_eq!(col.get_entry_mut(0), Some(SparseEntryMut::Zero));
|
||||||
|
assert_eq!(col.get_entry_mut(1), Some(SparseEntryMut::NonZero(&mut 1)));
|
||||||
|
assert_eq!(col.get_entry_mut(2), Some(SparseEntryMut::NonZero(&mut 2)));
|
||||||
|
assert_eq!(col.get_entry_mut(3), Some(SparseEntryMut::Zero));
|
||||||
|
assert_eq!(col.get_entry_mut(4), None);
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
let mut col = col_iter.next().unwrap();
|
||||||
|
assert_eq!(col.nrows(), 4);
|
||||||
|
assert_eq!(col.nnz(), 1);
|
||||||
|
assert_eq!(col.row_indices(), &[0]);
|
||||||
|
assert_eq!(col.values(), &[3]);
|
||||||
|
assert_eq!(col.get_entry(0), Some(SparseEntry::NonZero(&3)));
|
||||||
|
assert_eq!(col.get_entry(1), Some(SparseEntry::Zero));
|
||||||
|
assert_eq!(col.get_entry(2), Some(SparseEntry::Zero));
|
||||||
|
assert_eq!(col.get_entry(3), Some(SparseEntry::Zero));
|
||||||
|
assert_eq!(col.get_entry(4), None);
|
||||||
|
|
||||||
|
assert_eq!(col.values_mut(), &mut [3]);
|
||||||
|
assert_eq!(col.rows_and_values_mut(), ([0].as_ref(), [3].as_mut()));
|
||||||
|
assert_eq!(col.get_entry_mut(0), Some(SparseEntryMut::NonZero(&mut 3)));
|
||||||
|
assert_eq!(col.get_entry_mut(1), Some(SparseEntryMut::Zero));
|
||||||
|
assert_eq!(col.get_entry_mut(2), Some(SparseEntryMut::Zero));
|
||||||
|
assert_eq!(col.get_entry_mut(3), Some(SparseEntryMut::Zero));
|
||||||
|
assert_eq!(col.get_entry_mut(4), None);
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
let mut col = col_iter.next().unwrap();
|
||||||
|
assert_eq!(col.nrows(), 4);
|
||||||
|
assert_eq!(col.nnz(), 2);
|
||||||
|
assert_eq!(col.row_indices(), &[1, 3]);
|
||||||
|
assert_eq!(col.values(), &[4, 5]);
|
||||||
|
assert_eq!(col.get_entry(0), Some(SparseEntry::Zero));
|
||||||
|
assert_eq!(col.get_entry(1), Some(SparseEntry::NonZero(&4)));
|
||||||
|
assert_eq!(col.get_entry(2), Some(SparseEntry::Zero));
|
||||||
|
assert_eq!(col.get_entry(3), Some(SparseEntry::NonZero(&5)));
|
||||||
|
assert_eq!(col.get_entry(4), None);
|
||||||
|
|
||||||
|
assert_eq!(col.values_mut(), &mut [4, 5]);
|
||||||
|
assert_eq!(
|
||||||
|
col.rows_and_values_mut(),
|
||||||
|
([1, 3].as_ref(), [4, 5].as_mut())
|
||||||
|
);
|
||||||
|
assert_eq!(col.get_entry_mut(0), Some(SparseEntryMut::Zero));
|
||||||
|
assert_eq!(col.get_entry_mut(1), Some(SparseEntryMut::NonZero(&mut 4)));
|
||||||
|
assert_eq!(col.get_entry_mut(2), Some(SparseEntryMut::Zero));
|
||||||
|
assert_eq!(col.get_entry_mut(3), Some(SparseEntryMut::NonZero(&mut 5)));
|
||||||
|
assert_eq!(col.get_entry_mut(4), None);
|
||||||
|
}
|
||||||
|
|
||||||
|
assert!(col_iter.next().is_none());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
proptest! {
|
||||||
|
#[test]
|
||||||
|
fn csc_double_transpose_is_identity(csc in csc_strategy()) {
|
||||||
|
prop_assert_eq!(csc.transpose().transpose(), csc);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn csc_transpose_agrees_with_dense(csc in csc_strategy()) {
|
||||||
|
let dense_transpose = DMatrix::from(&csc).transpose();
|
||||||
|
let csc_transpose = csc.transpose();
|
||||||
|
prop_assert_eq!(dense_transpose, DMatrix::from(&csc_transpose));
|
||||||
|
prop_assert_eq!(csc.nnz(), csc_transpose.nnz());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn csc_filter(
|
||||||
|
(csc, triplet_subset)
|
||||||
|
in csc_strategy()
|
||||||
|
.prop_flat_map(|matrix| {
|
||||||
|
let triplets: Vec<_> = matrix.triplet_iter().cloned_values().collect();
|
||||||
|
let subset = subsequence(triplets, 0 ..= matrix.nnz())
|
||||||
|
.prop_map(|triplet_subset| {
|
||||||
|
let set: HashSet<_> = triplet_subset.into_iter().collect();
|
||||||
|
set
|
||||||
|
});
|
||||||
|
(Just(matrix), subset)
|
||||||
|
}))
|
||||||
|
{
|
||||||
|
// We generate a CscMatrix and a HashSet corresponding to a subset of the (i, j, v)
|
||||||
|
// values in the matrix, which we use for filtering the matrix entries.
|
||||||
|
// The resulting triplets in the filtered matrix must then be exactly equal to
|
||||||
|
// the subset.
|
||||||
|
let filtered = csc.filter(|i, j, v| triplet_subset.contains(&(i, j, *v)));
|
||||||
|
let filtered_triplets: HashSet<_> = filtered
|
||||||
|
.triplet_iter()
|
||||||
|
.cloned_values()
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
prop_assert_eq!(filtered_triplets, triplet_subset);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn csc_lower_triangle_agrees_with_dense(csc in csc_strategy()) {
|
||||||
|
let csc_lower_triangle = csc.lower_triangle();
|
||||||
|
prop_assert_eq!(DMatrix::from(&csc_lower_triangle), DMatrix::from(&csc).lower_triangle());
|
||||||
|
prop_assert!(csc_lower_triangle.nnz() <= csc.nnz());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn csc_upper_triangle_agrees_with_dense(csc in csc_strategy()) {
|
||||||
|
let csc_upper_triangle = csc.upper_triangle();
|
||||||
|
prop_assert_eq!(DMatrix::from(&csc_upper_triangle), DMatrix::from(&csc).upper_triangle());
|
||||||
|
prop_assert!(csc_upper_triangle.nnz() <= csc.nnz());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn csc_diagonal_as_csc(csc in csc_strategy()) {
|
||||||
|
let d = csc.diagonal_as_csc();
|
||||||
|
let d_entries: HashSet<_> = d.triplet_iter().cloned_values().collect();
|
||||||
|
let csc_diagonal_entries: HashSet<_> = csc
|
||||||
|
.triplet_iter()
|
||||||
|
.cloned_values()
|
||||||
|
.filter(|&(i, j, _)| i == j)
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
prop_assert_eq!(d_entries, csc_diagonal_entries);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn csc_identity(n in 0 ..= 6usize) {
|
||||||
|
let csc = CscMatrix::<i32>::identity(n);
|
||||||
|
prop_assert_eq!(csc.nnz(), n);
|
||||||
|
prop_assert_eq!(DMatrix::from(&csc), DMatrix::identity(n, n));
|
||||||
|
}
|
||||||
|
}
|
601
nalgebra-sparse/tests/unit_tests/csr.rs
Normal file
601
nalgebra-sparse/tests/unit_tests/csr.rs
Normal file
@ -0,0 +1,601 @@
|
|||||||
|
use nalgebra::DMatrix;
|
||||||
|
use nalgebra_sparse::csr::CsrMatrix;
|
||||||
|
use nalgebra_sparse::{SparseEntry, SparseEntryMut, SparseFormatErrorKind};
|
||||||
|
|
||||||
|
use proptest::prelude::*;
|
||||||
|
use proptest::sample::subsequence;
|
||||||
|
|
||||||
|
use crate::assert_panics;
|
||||||
|
use crate::common::csr_strategy;
|
||||||
|
|
||||||
|
use std::collections::HashSet;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn csr_matrix_valid_data() {
|
||||||
|
// Construct matrix from valid data and check that selected methods return results
|
||||||
|
// that agree with expectations.
|
||||||
|
|
||||||
|
{
|
||||||
|
// A CSR matrix with zero explicitly stored entries
|
||||||
|
let offsets = vec![0, 0, 0, 0];
|
||||||
|
let indices = vec![];
|
||||||
|
let values = Vec::<i32>::new();
|
||||||
|
let mut matrix = CsrMatrix::try_from_csr_data(3, 2, offsets, indices, values).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(matrix, CsrMatrix::zeros(3, 2));
|
||||||
|
|
||||||
|
assert_eq!(matrix.nrows(), 3);
|
||||||
|
assert_eq!(matrix.ncols(), 2);
|
||||||
|
assert_eq!(matrix.nnz(), 0);
|
||||||
|
assert_eq!(matrix.row_offsets(), &[0, 0, 0, 0]);
|
||||||
|
assert_eq!(matrix.col_indices(), &[]);
|
||||||
|
assert_eq!(matrix.values(), &[]);
|
||||||
|
|
||||||
|
assert!(matrix.triplet_iter().next().is_none());
|
||||||
|
assert!(matrix.triplet_iter_mut().next().is_none());
|
||||||
|
|
||||||
|
assert_eq!(matrix.row(0).ncols(), 2);
|
||||||
|
assert_eq!(matrix.row(0).nnz(), 0);
|
||||||
|
assert_eq!(matrix.row(0).col_indices(), &[]);
|
||||||
|
assert_eq!(matrix.row(0).values(), &[]);
|
||||||
|
assert_eq!(matrix.row_mut(0).ncols(), 2);
|
||||||
|
assert_eq!(matrix.row_mut(0).nnz(), 0);
|
||||||
|
assert_eq!(matrix.row_mut(0).col_indices(), &[]);
|
||||||
|
assert_eq!(matrix.row_mut(0).values(), &[]);
|
||||||
|
assert_eq!(matrix.row_mut(0).values_mut(), &[]);
|
||||||
|
assert_eq!(
|
||||||
|
matrix.row_mut(0).cols_and_values_mut(),
|
||||||
|
([].as_ref(), [].as_mut())
|
||||||
|
);
|
||||||
|
|
||||||
|
assert_eq!(matrix.row(1).ncols(), 2);
|
||||||
|
assert_eq!(matrix.row(1).nnz(), 0);
|
||||||
|
assert_eq!(matrix.row(1).col_indices(), &[]);
|
||||||
|
assert_eq!(matrix.row(1).values(), &[]);
|
||||||
|
assert_eq!(matrix.row_mut(1).ncols(), 2);
|
||||||
|
assert_eq!(matrix.row_mut(1).nnz(), 0);
|
||||||
|
assert_eq!(matrix.row_mut(1).col_indices(), &[]);
|
||||||
|
assert_eq!(matrix.row_mut(1).values(), &[]);
|
||||||
|
assert_eq!(matrix.row_mut(1).values_mut(), &[]);
|
||||||
|
assert_eq!(
|
||||||
|
matrix.row_mut(1).cols_and_values_mut(),
|
||||||
|
([].as_ref(), [].as_mut())
|
||||||
|
);
|
||||||
|
|
||||||
|
assert_eq!(matrix.row(2).ncols(), 2);
|
||||||
|
assert_eq!(matrix.row(2).nnz(), 0);
|
||||||
|
assert_eq!(matrix.row(2).col_indices(), &[]);
|
||||||
|
assert_eq!(matrix.row(2).values(), &[]);
|
||||||
|
assert_eq!(matrix.row_mut(2).ncols(), 2);
|
||||||
|
assert_eq!(matrix.row_mut(2).nnz(), 0);
|
||||||
|
assert_eq!(matrix.row_mut(2).col_indices(), &[]);
|
||||||
|
assert_eq!(matrix.row_mut(2).values(), &[]);
|
||||||
|
assert_eq!(matrix.row_mut(2).values_mut(), &[]);
|
||||||
|
assert_eq!(
|
||||||
|
matrix.row_mut(2).cols_and_values_mut(),
|
||||||
|
([].as_ref(), [].as_mut())
|
||||||
|
);
|
||||||
|
|
||||||
|
assert!(matrix.get_row(3).is_none());
|
||||||
|
assert!(matrix.get_row_mut(3).is_none());
|
||||||
|
|
||||||
|
let (offsets, indices, values) = matrix.disassemble();
|
||||||
|
|
||||||
|
assert_eq!(offsets, vec![0, 0, 0, 0]);
|
||||||
|
assert_eq!(indices, vec![]);
|
||||||
|
assert_eq!(values, vec![]);
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
// An arbitrary CSR matrix
|
||||||
|
let offsets = vec![0, 2, 2, 5];
|
||||||
|
let indices = vec![0, 5, 1, 2, 3];
|
||||||
|
let values = vec![0, 1, 2, 3, 4];
|
||||||
|
let mut matrix =
|
||||||
|
CsrMatrix::try_from_csr_data(3, 6, offsets.clone(), indices.clone(), values.clone())
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert_eq!(matrix.nrows(), 3);
|
||||||
|
assert_eq!(matrix.ncols(), 6);
|
||||||
|
assert_eq!(matrix.nnz(), 5);
|
||||||
|
assert_eq!(matrix.row_offsets(), &[0, 2, 2, 5]);
|
||||||
|
assert_eq!(matrix.col_indices(), &[0, 5, 1, 2, 3]);
|
||||||
|
assert_eq!(matrix.values(), &[0, 1, 2, 3, 4]);
|
||||||
|
|
||||||
|
let expected_triplets = vec![(0, 0, 0), (0, 5, 1), (2, 1, 2), (2, 2, 3), (2, 3, 4)];
|
||||||
|
assert_eq!(
|
||||||
|
matrix
|
||||||
|
.triplet_iter()
|
||||||
|
.map(|(i, j, v)| (i, j, *v))
|
||||||
|
.collect::<Vec<_>>(),
|
||||||
|
expected_triplets
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
matrix
|
||||||
|
.triplet_iter_mut()
|
||||||
|
.map(|(i, j, v)| (i, j, *v))
|
||||||
|
.collect::<Vec<_>>(),
|
||||||
|
expected_triplets
|
||||||
|
);
|
||||||
|
|
||||||
|
assert_eq!(matrix.row(0).ncols(), 6);
|
||||||
|
assert_eq!(matrix.row(0).nnz(), 2);
|
||||||
|
assert_eq!(matrix.row(0).col_indices(), &[0, 5]);
|
||||||
|
assert_eq!(matrix.row(0).values(), &[0, 1]);
|
||||||
|
assert_eq!(matrix.row_mut(0).ncols(), 6);
|
||||||
|
assert_eq!(matrix.row_mut(0).nnz(), 2);
|
||||||
|
assert_eq!(matrix.row_mut(0).col_indices(), &[0, 5]);
|
||||||
|
assert_eq!(matrix.row_mut(0).values(), &[0, 1]);
|
||||||
|
assert_eq!(matrix.row_mut(0).values_mut(), &[0, 1]);
|
||||||
|
assert_eq!(
|
||||||
|
matrix.row_mut(0).cols_and_values_mut(),
|
||||||
|
([0, 5].as_ref(), [0, 1].as_mut())
|
||||||
|
);
|
||||||
|
|
||||||
|
assert_eq!(matrix.row(1).ncols(), 6);
|
||||||
|
assert_eq!(matrix.row(1).nnz(), 0);
|
||||||
|
assert_eq!(matrix.row(1).col_indices(), &[]);
|
||||||
|
assert_eq!(matrix.row(1).values(), &[]);
|
||||||
|
assert_eq!(matrix.row_mut(1).ncols(), 6);
|
||||||
|
assert_eq!(matrix.row_mut(1).nnz(), 0);
|
||||||
|
assert_eq!(matrix.row_mut(1).col_indices(), &[]);
|
||||||
|
assert_eq!(matrix.row_mut(1).values(), &[]);
|
||||||
|
assert_eq!(matrix.row_mut(1).values_mut(), &[]);
|
||||||
|
assert_eq!(
|
||||||
|
matrix.row_mut(1).cols_and_values_mut(),
|
||||||
|
([].as_ref(), [].as_mut())
|
||||||
|
);
|
||||||
|
|
||||||
|
assert_eq!(matrix.row(2).ncols(), 6);
|
||||||
|
assert_eq!(matrix.row(2).nnz(), 3);
|
||||||
|
assert_eq!(matrix.row(2).col_indices(), &[1, 2, 3]);
|
||||||
|
assert_eq!(matrix.row(2).values(), &[2, 3, 4]);
|
||||||
|
assert_eq!(matrix.row_mut(2).ncols(), 6);
|
||||||
|
assert_eq!(matrix.row_mut(2).nnz(), 3);
|
||||||
|
assert_eq!(matrix.row_mut(2).col_indices(), &[1, 2, 3]);
|
||||||
|
assert_eq!(matrix.row_mut(2).values(), &[2, 3, 4]);
|
||||||
|
assert_eq!(matrix.row_mut(2).values_mut(), &[2, 3, 4]);
|
||||||
|
assert_eq!(
|
||||||
|
matrix.row_mut(2).cols_and_values_mut(),
|
||||||
|
([1, 2, 3].as_ref(), [2, 3, 4].as_mut())
|
||||||
|
);
|
||||||
|
|
||||||
|
assert!(matrix.get_row(3).is_none());
|
||||||
|
assert!(matrix.get_row_mut(3).is_none());
|
||||||
|
|
||||||
|
let (offsets2, indices2, values2) = matrix.disassemble();
|
||||||
|
|
||||||
|
assert_eq!(offsets2, offsets);
|
||||||
|
assert_eq!(indices2, indices);
|
||||||
|
assert_eq!(values2, values);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn csr_matrix_try_from_invalid_csr_data() {
|
||||||
|
{
|
||||||
|
// Empty offset array (invalid length)
|
||||||
|
let matrix = CsrMatrix::try_from_csr_data(0, 0, Vec::new(), Vec::new(), Vec::<u32>::new());
|
||||||
|
assert_eq!(
|
||||||
|
matrix.unwrap_err().kind(),
|
||||||
|
&SparseFormatErrorKind::InvalidStructure
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
// Offset array invalid length for arbitrary data
|
||||||
|
let offsets = vec![0, 3, 5];
|
||||||
|
let indices = vec![0, 1, 2, 3, 5];
|
||||||
|
let values = vec![0, 1, 2, 3, 4];
|
||||||
|
|
||||||
|
let matrix = CsrMatrix::try_from_csr_data(3, 6, offsets, indices, values);
|
||||||
|
assert_eq!(
|
||||||
|
matrix.unwrap_err().kind(),
|
||||||
|
&SparseFormatErrorKind::InvalidStructure
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
// Invalid first entry in offsets array
|
||||||
|
let offsets = vec![1, 2, 2, 5];
|
||||||
|
let indices = vec![0, 5, 1, 2, 3];
|
||||||
|
let values = vec![0, 1, 2, 3, 4];
|
||||||
|
let matrix = CsrMatrix::try_from_csr_data(3, 6, offsets, indices, values);
|
||||||
|
assert_eq!(
|
||||||
|
matrix.unwrap_err().kind(),
|
||||||
|
&SparseFormatErrorKind::InvalidStructure
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
// Invalid last entry in offsets array
|
||||||
|
let offsets = vec![0, 2, 2, 4];
|
||||||
|
let indices = vec![0, 5, 1, 2, 3];
|
||||||
|
let values = vec![0, 1, 2, 3, 4];
|
||||||
|
let matrix = CsrMatrix::try_from_csr_data(3, 6, offsets, indices, values);
|
||||||
|
assert_eq!(
|
||||||
|
matrix.unwrap_err().kind(),
|
||||||
|
&SparseFormatErrorKind::InvalidStructure
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
// Invalid length of offsets array
|
||||||
|
let offsets = vec![0, 2, 2];
|
||||||
|
let indices = vec![0, 5, 1, 2, 3];
|
||||||
|
let values = vec![0, 1, 2, 3, 4];
|
||||||
|
let matrix = CsrMatrix::try_from_csr_data(3, 6, offsets, indices, values);
|
||||||
|
assert_eq!(
|
||||||
|
matrix.unwrap_err().kind(),
|
||||||
|
&SparseFormatErrorKind::InvalidStructure
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
// Nonmonotonic offsets
|
||||||
|
let offsets = vec![0, 3, 2, 5];
|
||||||
|
let indices = vec![0, 1, 2, 3, 4];
|
||||||
|
let values = vec![0, 1, 2, 3, 4];
|
||||||
|
let matrix = CsrMatrix::try_from_csr_data(3, 6, offsets, indices, values);
|
||||||
|
assert_eq!(
|
||||||
|
matrix.unwrap_err().kind(),
|
||||||
|
&SparseFormatErrorKind::InvalidStructure
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
// Nonmonotonic minor indices
|
||||||
|
let offsets = vec![0, 2, 2, 5];
|
||||||
|
let indices = vec![0, 2, 3, 1, 4];
|
||||||
|
let values = vec![0, 1, 2, 3, 4];
|
||||||
|
let matrix = CsrMatrix::try_from_csr_data(3, 6, offsets, indices, values);
|
||||||
|
assert_eq!(
|
||||||
|
matrix.unwrap_err().kind(),
|
||||||
|
&SparseFormatErrorKind::InvalidStructure
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
// Minor index out of bounds
|
||||||
|
let offsets = vec![0, 2, 2, 5];
|
||||||
|
let indices = vec![0, 6, 1, 2, 3];
|
||||||
|
let values = vec![0, 1, 2, 3, 4];
|
||||||
|
let matrix = CsrMatrix::try_from_csr_data(3, 6, offsets, indices, values);
|
||||||
|
assert_eq!(
|
||||||
|
matrix.unwrap_err().kind(),
|
||||||
|
&SparseFormatErrorKind::IndexOutOfBounds
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
// Duplicate entry
|
||||||
|
let offsets = vec![0, 2, 2, 5];
|
||||||
|
let indices = vec![0, 5, 2, 2, 3];
|
||||||
|
let values = vec![0, 1, 2, 3, 4];
|
||||||
|
let matrix = CsrMatrix::try_from_csr_data(3, 6, offsets, indices, values);
|
||||||
|
assert_eq!(
|
||||||
|
matrix.unwrap_err().kind(),
|
||||||
|
&SparseFormatErrorKind::DuplicateEntry
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn csr_disassemble_avoids_clone_when_owned() {
|
||||||
|
// Test that disassemble avoids cloning the sparsity pattern when it holds the sole reference
|
||||||
|
// to the pattern. We do so by checking that the pointer to the data is unchanged.
|
||||||
|
|
||||||
|
let offsets = vec![0, 2, 2, 5];
|
||||||
|
let indices = vec![0, 5, 1, 2, 3];
|
||||||
|
let values = vec![0, 1, 2, 3, 4];
|
||||||
|
let offsets_ptr = offsets.as_ptr();
|
||||||
|
let indices_ptr = indices.as_ptr();
|
||||||
|
let values_ptr = values.as_ptr();
|
||||||
|
let matrix = CsrMatrix::try_from_csr_data(3, 6, offsets, indices, values).unwrap();
|
||||||
|
|
||||||
|
let (offsets, indices, values) = matrix.disassemble();
|
||||||
|
assert_eq!(offsets.as_ptr(), offsets_ptr);
|
||||||
|
assert_eq!(indices.as_ptr(), indices_ptr);
|
||||||
|
assert_eq!(values.as_ptr(), values_ptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Rustfmt makes this test much harder to read by expanding some of the one-liners to 4-liners,
|
||||||
|
// so for now we skip rustfmt...
|
||||||
|
#[rustfmt::skip]
|
||||||
|
#[test]
|
||||||
|
fn csr_matrix_get_index_entry() {
|
||||||
|
// Test .get_entry(_mut) and .index_entry(_mut) methods
|
||||||
|
|
||||||
|
#[rustfmt::skip]
|
||||||
|
let dense = DMatrix::from_row_slice(2, 3, &[
|
||||||
|
1, 0, 3,
|
||||||
|
0, 5, 6
|
||||||
|
]);
|
||||||
|
let csr = CsrMatrix::from(&dense);
|
||||||
|
|
||||||
|
assert_eq!(csr.get_entry(0, 0), Some(SparseEntry::NonZero(&1)));
|
||||||
|
assert_eq!(csr.index_entry(0, 0), SparseEntry::NonZero(&1));
|
||||||
|
assert_eq!(csr.get_entry(0, 1), Some(SparseEntry::Zero));
|
||||||
|
assert_eq!(csr.index_entry(0, 1), SparseEntry::Zero);
|
||||||
|
assert_eq!(csr.get_entry(0, 2), Some(SparseEntry::NonZero(&3)));
|
||||||
|
assert_eq!(csr.index_entry(0, 2), SparseEntry::NonZero(&3));
|
||||||
|
assert_eq!(csr.get_entry(1, 0), Some(SparseEntry::Zero));
|
||||||
|
assert_eq!(csr.index_entry(1, 0), SparseEntry::Zero);
|
||||||
|
assert_eq!(csr.get_entry(1, 1), Some(SparseEntry::NonZero(&5)));
|
||||||
|
assert_eq!(csr.index_entry(1, 1), SparseEntry::NonZero(&5));
|
||||||
|
assert_eq!(csr.get_entry(1, 2), Some(SparseEntry::NonZero(&6)));
|
||||||
|
assert_eq!(csr.index_entry(1, 2), SparseEntry::NonZero(&6));
|
||||||
|
|
||||||
|
// Check some out of bounds with .get_entry
|
||||||
|
assert_eq!(csr.get_entry(0, 3), None);
|
||||||
|
assert_eq!(csr.get_entry(0, 4), None);
|
||||||
|
assert_eq!(csr.get_entry(1, 3), None);
|
||||||
|
assert_eq!(csr.get_entry(1, 4), None);
|
||||||
|
assert_eq!(csr.get_entry(2, 0), None);
|
||||||
|
assert_eq!(csr.get_entry(2, 1), None);
|
||||||
|
assert_eq!(csr.get_entry(2, 2), None);
|
||||||
|
assert_eq!(csr.get_entry(2, 3), None);
|
||||||
|
assert_eq!(csr.get_entry(2, 4), None);
|
||||||
|
|
||||||
|
// Check that out of bounds with .index_entry panics
|
||||||
|
assert_panics!(csr.index_entry(0, 3));
|
||||||
|
assert_panics!(csr.index_entry(0, 4));
|
||||||
|
assert_panics!(csr.index_entry(1, 3));
|
||||||
|
assert_panics!(csr.index_entry(1, 4));
|
||||||
|
assert_panics!(csr.index_entry(2, 0));
|
||||||
|
assert_panics!(csr.index_entry(2, 1));
|
||||||
|
assert_panics!(csr.index_entry(2, 2));
|
||||||
|
assert_panics!(csr.index_entry(2, 3));
|
||||||
|
assert_panics!(csr.index_entry(2, 4));
|
||||||
|
|
||||||
|
{
|
||||||
|
// Check mutable versions of the above functions
|
||||||
|
let mut csr = csr;
|
||||||
|
|
||||||
|
assert_eq!(csr.get_entry_mut(0, 0), Some(SparseEntryMut::NonZero(&mut 1)));
|
||||||
|
assert_eq!(csr.index_entry_mut(0, 0), SparseEntryMut::NonZero(&mut 1));
|
||||||
|
assert_eq!(csr.get_entry_mut(0, 1), Some(SparseEntryMut::Zero));
|
||||||
|
assert_eq!(csr.index_entry_mut(0, 1), SparseEntryMut::Zero);
|
||||||
|
assert_eq!(csr.get_entry_mut(0, 2), Some(SparseEntryMut::NonZero(&mut 3)));
|
||||||
|
assert_eq!(csr.index_entry_mut(0, 2), SparseEntryMut::NonZero(&mut 3));
|
||||||
|
assert_eq!(csr.get_entry_mut(1, 0), Some(SparseEntryMut::Zero));
|
||||||
|
assert_eq!(csr.index_entry_mut(1, 0), SparseEntryMut::Zero);
|
||||||
|
assert_eq!(csr.get_entry_mut(1, 1), Some(SparseEntryMut::NonZero(&mut 5)));
|
||||||
|
assert_eq!(csr.index_entry_mut(1, 1), SparseEntryMut::NonZero(&mut 5));
|
||||||
|
assert_eq!(csr.get_entry_mut(1, 2), Some(SparseEntryMut::NonZero(&mut 6)));
|
||||||
|
assert_eq!(csr.index_entry_mut(1, 2), SparseEntryMut::NonZero(&mut 6));
|
||||||
|
|
||||||
|
// Check some out of bounds with .get_entry_mut
|
||||||
|
assert_eq!(csr.get_entry_mut(0, 3), None);
|
||||||
|
assert_eq!(csr.get_entry_mut(0, 4), None);
|
||||||
|
assert_eq!(csr.get_entry_mut(1, 3), None);
|
||||||
|
assert_eq!(csr.get_entry_mut(1, 4), None);
|
||||||
|
assert_eq!(csr.get_entry_mut(2, 0), None);
|
||||||
|
assert_eq!(csr.get_entry_mut(2, 1), None);
|
||||||
|
assert_eq!(csr.get_entry_mut(2, 2), None);
|
||||||
|
assert_eq!(csr.get_entry_mut(2, 3), None);
|
||||||
|
assert_eq!(csr.get_entry_mut(2, 4), None);
|
||||||
|
|
||||||
|
// Check that out of bounds with .index_entry_mut panics
|
||||||
|
// Note: the cloning is necessary because a mutable reference is not UnwindSafe
|
||||||
|
assert_panics!({ let mut csr = csr.clone(); csr.index_entry_mut(0, 3); });
|
||||||
|
assert_panics!({ let mut csr = csr.clone(); csr.index_entry_mut(0, 4); });
|
||||||
|
assert_panics!({ let mut csr = csr.clone(); csr.index_entry_mut(1, 3); });
|
||||||
|
assert_panics!({ let mut csr = csr.clone(); csr.index_entry_mut(1, 4); });
|
||||||
|
assert_panics!({ let mut csr = csr.clone(); csr.index_entry_mut(2, 0); });
|
||||||
|
assert_panics!({ let mut csr = csr.clone(); csr.index_entry_mut(2, 1); });
|
||||||
|
assert_panics!({ let mut csr = csr.clone(); csr.index_entry_mut(2, 2); });
|
||||||
|
assert_panics!({ let mut csr = csr.clone(); csr.index_entry_mut(2, 3); });
|
||||||
|
assert_panics!({ let mut csr = csr.clone(); csr.index_entry_mut(2, 4); });
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn csr_matrix_row_iter() {
|
||||||
|
#[rustfmt::skip]
|
||||||
|
let dense = DMatrix::from_row_slice(3, 4, &[
|
||||||
|
0, 1, 2, 0,
|
||||||
|
3, 0, 0, 0,
|
||||||
|
0, 4, 0, 5
|
||||||
|
]);
|
||||||
|
let csr = CsrMatrix::from(&dense);
|
||||||
|
|
||||||
|
// Immutable iterator
|
||||||
|
{
|
||||||
|
let mut row_iter = csr.row_iter();
|
||||||
|
|
||||||
|
{
|
||||||
|
let row = row_iter.next().unwrap();
|
||||||
|
assert_eq!(row.ncols(), 4);
|
||||||
|
assert_eq!(row.nnz(), 2);
|
||||||
|
assert_eq!(row.col_indices(), &[1, 2]);
|
||||||
|
assert_eq!(row.values(), &[1, 2]);
|
||||||
|
assert_eq!(row.get_entry(0), Some(SparseEntry::Zero));
|
||||||
|
assert_eq!(row.get_entry(1), Some(SparseEntry::NonZero(&1)));
|
||||||
|
assert_eq!(row.get_entry(2), Some(SparseEntry::NonZero(&2)));
|
||||||
|
assert_eq!(row.get_entry(3), Some(SparseEntry::Zero));
|
||||||
|
assert_eq!(row.get_entry(4), None);
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
let row = row_iter.next().unwrap();
|
||||||
|
assert_eq!(row.ncols(), 4);
|
||||||
|
assert_eq!(row.nnz(), 1);
|
||||||
|
assert_eq!(row.col_indices(), &[0]);
|
||||||
|
assert_eq!(row.values(), &[3]);
|
||||||
|
assert_eq!(row.get_entry(0), Some(SparseEntry::NonZero(&3)));
|
||||||
|
assert_eq!(row.get_entry(1), Some(SparseEntry::Zero));
|
||||||
|
assert_eq!(row.get_entry(2), Some(SparseEntry::Zero));
|
||||||
|
assert_eq!(row.get_entry(3), Some(SparseEntry::Zero));
|
||||||
|
assert_eq!(row.get_entry(4), None);
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
let row = row_iter.next().unwrap();
|
||||||
|
assert_eq!(row.ncols(), 4);
|
||||||
|
assert_eq!(row.nnz(), 2);
|
||||||
|
assert_eq!(row.col_indices(), &[1, 3]);
|
||||||
|
assert_eq!(row.values(), &[4, 5]);
|
||||||
|
assert_eq!(row.get_entry(0), Some(SparseEntry::Zero));
|
||||||
|
assert_eq!(row.get_entry(1), Some(SparseEntry::NonZero(&4)));
|
||||||
|
assert_eq!(row.get_entry(2), Some(SparseEntry::Zero));
|
||||||
|
assert_eq!(row.get_entry(3), Some(SparseEntry::NonZero(&5)));
|
||||||
|
assert_eq!(row.get_entry(4), None);
|
||||||
|
}
|
||||||
|
|
||||||
|
assert!(row_iter.next().is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mutable iterator
|
||||||
|
{
|
||||||
|
let mut csr = csr;
|
||||||
|
let mut row_iter = csr.row_iter_mut();
|
||||||
|
|
||||||
|
{
|
||||||
|
let mut row = row_iter.next().unwrap();
|
||||||
|
assert_eq!(row.ncols(), 4);
|
||||||
|
assert_eq!(row.nnz(), 2);
|
||||||
|
assert_eq!(row.col_indices(), &[1, 2]);
|
||||||
|
assert_eq!(row.values(), &[1, 2]);
|
||||||
|
assert_eq!(row.get_entry(0), Some(SparseEntry::Zero));
|
||||||
|
assert_eq!(row.get_entry(1), Some(SparseEntry::NonZero(&1)));
|
||||||
|
assert_eq!(row.get_entry(2), Some(SparseEntry::NonZero(&2)));
|
||||||
|
assert_eq!(row.get_entry(3), Some(SparseEntry::Zero));
|
||||||
|
assert_eq!(row.get_entry(4), None);
|
||||||
|
|
||||||
|
assert_eq!(row.values_mut(), &mut [1, 2]);
|
||||||
|
assert_eq!(
|
||||||
|
row.cols_and_values_mut(),
|
||||||
|
([1, 2].as_ref(), [1, 2].as_mut())
|
||||||
|
);
|
||||||
|
assert_eq!(row.get_entry_mut(0), Some(SparseEntryMut::Zero));
|
||||||
|
assert_eq!(row.get_entry_mut(1), Some(SparseEntryMut::NonZero(&mut 1)));
|
||||||
|
assert_eq!(row.get_entry_mut(2), Some(SparseEntryMut::NonZero(&mut 2)));
|
||||||
|
assert_eq!(row.get_entry_mut(3), Some(SparseEntryMut::Zero));
|
||||||
|
assert_eq!(row.get_entry_mut(4), None);
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
let mut row = row_iter.next().unwrap();
|
||||||
|
assert_eq!(row.ncols(), 4);
|
||||||
|
assert_eq!(row.nnz(), 1);
|
||||||
|
assert_eq!(row.col_indices(), &[0]);
|
||||||
|
assert_eq!(row.values(), &[3]);
|
||||||
|
assert_eq!(row.get_entry(0), Some(SparseEntry::NonZero(&3)));
|
||||||
|
assert_eq!(row.get_entry(1), Some(SparseEntry::Zero));
|
||||||
|
assert_eq!(row.get_entry(2), Some(SparseEntry::Zero));
|
||||||
|
assert_eq!(row.get_entry(3), Some(SparseEntry::Zero));
|
||||||
|
assert_eq!(row.get_entry(4), None);
|
||||||
|
|
||||||
|
assert_eq!(row.values_mut(), &mut [3]);
|
||||||
|
assert_eq!(row.cols_and_values_mut(), ([0].as_ref(), [3].as_mut()));
|
||||||
|
assert_eq!(row.get_entry_mut(0), Some(SparseEntryMut::NonZero(&mut 3)));
|
||||||
|
assert_eq!(row.get_entry_mut(1), Some(SparseEntryMut::Zero));
|
||||||
|
assert_eq!(row.get_entry_mut(2), Some(SparseEntryMut::Zero));
|
||||||
|
assert_eq!(row.get_entry_mut(3), Some(SparseEntryMut::Zero));
|
||||||
|
assert_eq!(row.get_entry_mut(4), None);
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
let mut row = row_iter.next().unwrap();
|
||||||
|
assert_eq!(row.ncols(), 4);
|
||||||
|
assert_eq!(row.nnz(), 2);
|
||||||
|
assert_eq!(row.col_indices(), &[1, 3]);
|
||||||
|
assert_eq!(row.values(), &[4, 5]);
|
||||||
|
assert_eq!(row.get_entry(0), Some(SparseEntry::Zero));
|
||||||
|
assert_eq!(row.get_entry(1), Some(SparseEntry::NonZero(&4)));
|
||||||
|
assert_eq!(row.get_entry(2), Some(SparseEntry::Zero));
|
||||||
|
assert_eq!(row.get_entry(3), Some(SparseEntry::NonZero(&5)));
|
||||||
|
assert_eq!(row.get_entry(4), None);
|
||||||
|
|
||||||
|
assert_eq!(row.values_mut(), &mut [4, 5]);
|
||||||
|
assert_eq!(
|
||||||
|
row.cols_and_values_mut(),
|
||||||
|
([1, 3].as_ref(), [4, 5].as_mut())
|
||||||
|
);
|
||||||
|
assert_eq!(row.get_entry_mut(0), Some(SparseEntryMut::Zero));
|
||||||
|
assert_eq!(row.get_entry_mut(1), Some(SparseEntryMut::NonZero(&mut 4)));
|
||||||
|
assert_eq!(row.get_entry_mut(2), Some(SparseEntryMut::Zero));
|
||||||
|
assert_eq!(row.get_entry_mut(3), Some(SparseEntryMut::NonZero(&mut 5)));
|
||||||
|
assert_eq!(row.get_entry_mut(4), None);
|
||||||
|
}
|
||||||
|
|
||||||
|
assert!(row_iter.next().is_none());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
proptest! {
|
||||||
|
#[test]
|
||||||
|
fn csr_double_transpose_is_identity(csr in csr_strategy()) {
|
||||||
|
prop_assert_eq!(csr.transpose().transpose(), csr);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn csr_transpose_agrees_with_dense(csr in csr_strategy()) {
|
||||||
|
let dense_transpose = DMatrix::from(&csr).transpose();
|
||||||
|
let csr_transpose = csr.transpose();
|
||||||
|
prop_assert_eq!(dense_transpose, DMatrix::from(&csr_transpose));
|
||||||
|
prop_assert_eq!(csr.nnz(), csr_transpose.nnz());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn csr_filter(
|
||||||
|
(csr, triplet_subset)
|
||||||
|
in csr_strategy()
|
||||||
|
.prop_flat_map(|matrix| {
|
||||||
|
let triplets: Vec<_> = matrix.triplet_iter().cloned_values().collect();
|
||||||
|
let subset = subsequence(triplets, 0 ..= matrix.nnz())
|
||||||
|
.prop_map(|triplet_subset| {
|
||||||
|
let set: HashSet<_> = triplet_subset.into_iter().collect();
|
||||||
|
set
|
||||||
|
});
|
||||||
|
(Just(matrix), subset)
|
||||||
|
}))
|
||||||
|
{
|
||||||
|
// We generate a CsrMatrix and a HashSet corresponding to a subset of the (i, j, v)
|
||||||
|
// values in the matrix, which we use for filtering the matrix entries.
|
||||||
|
// The resulting triplets in the filtered matrix must then be exactly equal to
|
||||||
|
// the subset.
|
||||||
|
let filtered = csr.filter(|i, j, v| triplet_subset.contains(&(i, j, *v)));
|
||||||
|
let filtered_triplets: HashSet<_> = filtered
|
||||||
|
.triplet_iter()
|
||||||
|
.cloned_values()
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
prop_assert_eq!(filtered_triplets, triplet_subset);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn csr_lower_triangle_agrees_with_dense(csr in csr_strategy()) {
|
||||||
|
let csr_lower_triangle = csr.lower_triangle();
|
||||||
|
prop_assert_eq!(DMatrix::from(&csr_lower_triangle), DMatrix::from(&csr).lower_triangle());
|
||||||
|
prop_assert!(csr_lower_triangle.nnz() <= csr.nnz());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn csr_upper_triangle_agrees_with_dense(csr in csr_strategy()) {
|
||||||
|
let csr_upper_triangle = csr.upper_triangle();
|
||||||
|
prop_assert_eq!(DMatrix::from(&csr_upper_triangle), DMatrix::from(&csr).upper_triangle());
|
||||||
|
prop_assert!(csr_upper_triangle.nnz() <= csr.nnz());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn csr_diagonal_as_csr(csr in csr_strategy()) {
|
||||||
|
let d = csr.diagonal_as_csr();
|
||||||
|
let d_entries: HashSet<_> = d.triplet_iter().cloned_values().collect();
|
||||||
|
let csr_diagonal_entries: HashSet<_> = csr
|
||||||
|
.triplet_iter()
|
||||||
|
.cloned_values()
|
||||||
|
.filter(|&(i, j, _)| i == j)
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
prop_assert_eq!(d_entries, csr_diagonal_entries);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn csr_identity(n in 0 ..= 6usize) {
|
||||||
|
let csr = CsrMatrix::<i32>::identity(n);
|
||||||
|
prop_assert_eq!(csr.nnz(), n);
|
||||||
|
prop_assert_eq!(DMatrix::from(&csr), DMatrix::identity(n, n));
|
||||||
|
}
|
||||||
|
}
|
8
nalgebra-sparse/tests/unit_tests/mod.rs
Normal file
8
nalgebra-sparse/tests/unit_tests/mod.rs
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
mod cholesky;
|
||||||
|
mod convert_serial;
|
||||||
|
mod coo;
|
||||||
|
mod csc;
|
||||||
|
mod csr;
|
||||||
|
mod ops;
|
||||||
|
mod pattern;
|
||||||
|
mod proptest;
|
14
nalgebra-sparse/tests/unit_tests/ops.proptest-regressions
Normal file
14
nalgebra-sparse/tests/unit_tests/ops.proptest-regressions
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
# Seeds for failure cases proptest has generated in the past. It is
|
||||||
|
# automatically read and these particular cases re-run before any
|
||||||
|
# novel cases are generated.
|
||||||
|
#
|
||||||
|
# It is recommended to check this file in to source control so that
|
||||||
|
# everyone who runs the test benefits from these saved cases.
|
||||||
|
cc 6748ea4ac9523fcc4dd8327b27c6818f8df10eb2042774f59a6e3fa3205dbcbd # shrinks to (beta, alpha, (c, a, b)) = (0, -1, (Matrix { data: VecStorage { data: [0, 0, 0, 0, 0, 1, 5, -4, 2], nrows: Dynamic { value: 3 }, ncols: Dynamic { value: 3 } } }, CsrMatrix { sparsity_pattern: SparsityPattern { major_offsets: [0, 2, 2, 2], minor_indices: [0, 1], minor_dim: 5 }, values: [-5, -2] }, Matrix { data: VecStorage { data: [4, -2, -3, -3, -5, 3, 5, 1, -4, -4, 3, 5, 5, 5, -3], nrows: Dynamic { value: 5 }, ncols: Dynamic { value: 3 } } }))
|
||||||
|
cc dcf67ab7b8febf109cfa58ee0f082b9f7c23d6ad0df2e28dc99984deeb6b113a # shrinks to (beta, alpha, (c, a, b)) = (0, 0, (Matrix { data: VecStorage { data: [0, -1], nrows: Dynamic { value: 1 }, ncols: Dynamic { value: 2 } } }, CsrMatrix { sparsity_pattern: SparsityPattern { major_offsets: [0, 0], minor_indices: [], minor_dim: 4 }, values: [] }, Matrix { data: VecStorage { data: [3, 1, 1, 0, 0, 3, -5, -3], nrows: Dynamic { value: 4 }, ncols: Dynamic { value: 2 } } }))
|
||||||
|
cc dbaef9886eaad28be7cd48326b857f039d695bc0b19e9ada3304e812e984d2c3 # shrinks to (beta, alpha, (c, a, b)) = (0, -1, (Matrix { data: VecStorage { data: [1], nrows: Dynamic { value: 1 }, ncols: Dynamic { value: 1 } } }, CsrMatrix { sparsity_pattern: SparsityPattern { major_offsets: [0, 0], minor_indices: [], minor_dim: 0 }, values: [] }, Matrix { data: VecStorage { data: [], nrows: Dynamic { value: 0 }, ncols: Dynamic { value: 1 } } }))
|
||||||
|
cc 99e312beb498ffa79194f41501ea312dce1911878eba131282904ac97205aaa9 # shrinks to SpmmCsrDenseArgs { c, beta, alpha, trans_a, a, trans_b, b } = SpmmCsrDenseArgs { c: Matrix { data: VecStorage { data: [-1, 4, -1, -4, 2, 1, 4, -2, 1, 3, -2, 5], nrows: Dynamic { value: 2 }, ncols: Dynamic { value: 6 } } }, beta: 0, alpha: 0, trans_a: Transpose, a: CsrMatrix { sparsity_pattern: SparsityPattern { major_offsets: [0, 1, 1, 1, 1, 1, 1], minor_indices: [0], minor_dim: 2 }, values: [0] }, trans_b: Transpose, b: Matrix { data: VecStorage { data: [-1, 1, 0, -5, 4, -5, 2, 2, 4, -4, -3, -1, 1, -1, 0, 1, -3, 4, -5, 0, 1, -5, 0, 1, 1, -3, 5, 3, 5, -3, -5, 3, -1, -4, -4, -3], nrows: Dynamic { value: 6 }, ncols: Dynamic { value: 6 } } } }
|
||||||
|
cc bf74259df2db6eda24eb42098e57ea1c604bb67d6d0023fa308c321027b53a43 # shrinks to (alpha, beta, c, a, b, trans_a, trans_b) = (0, 0, Matrix { data: VecStorage { data: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], nrows: Dynamic { value: 4 }, ncols: Dynamic { value: 5 } } }, CsrMatrix { sparsity_pattern: SparsityPattern { major_offsets: [0, 3, 6, 9, 12], minor_indices: [0, 1, 3, 1, 2, 3, 0, 1, 2, 1, 2, 3], minor_dim: 4 }, values: [-3, 3, -3, 1, -3, 0, 2, 1, 3, 0, -4, -1] }, Matrix { data: VecStorage { data: [3, 1, 4, -5, 5, -2, -5, -1, 1, -1, 3, -3, -2, 4, 2, -1, -1, 3, -5, 5], nrows: Dynamic { value: 4 }, ncols: Dynamic { value: 5 } } }, NoTranspose, NoTranspose)
|
||||||
|
cc cbd6dac45a2f610e10cf4c15d4614cdbf7dfedbfcd733e4cc65c2e79829d14b3 # shrinks to SpmmCsrArgs { c, beta, alpha, trans_a, a, trans_b, b } = SpmmCsrArgs { c: CsrMatrix { sparsity_pattern: SparsityPattern { major_offsets: [0, 0, 1, 1, 1, 1], minor_indices: [0], minor_dim: 1 }, values: [0] }, beta: 0, alpha: 1, trans_a: Transpose(true), a: CsrMatrix { sparsity_pattern: SparsityPattern { major_offsets: [0, 0, 0, 1, 1, 1], minor_indices: [1], minor_dim: 5 }, values: [-1] }, trans_b: Transpose(true), b: CsrMatrix { sparsity_pattern: SparsityPattern { major_offsets: [0, 2], minor_indices: [2, 4], minor_dim: 5 }, values: [-1, 0] } }
|
||||||
|
cc 8af78e2e41087743c8696c4d5563d59464f284662ccf85efc81ac56747d528bb # shrinks to (a, b) = (CscMatrix { cs: CsMatrix { sparsity_pattern: SparsityPattern { major_offsets: [0, 6, 12, 18, 24, 30, 33], minor_indices: [0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 1, 2, 5], minor_dim: 6 }, values: [0.4566433975117654, -0.5109683327713039, 0.0, -3.276901622678194, 0.0, -2.2065487385437095, 0.0, -0.42643054427847016, -2.9232369281581234, 0.0, 1.2913925579441763, 0.0, -1.4073766622090917, -4.795473113569459, 4.681765156869446, -0.821162215887913, 3.0315816068414794, -3.3986924718213407, -3.498903007282241, -3.1488953408335236, 3.458104636152161, -4.774694888508124, 2.603884664757498, 0.0, 0.0, -3.2650988857765535, 4.26699442646613, 0.0, -0.012223422086023561, 3.6899095325779285, -1.4264458042247958, 0.0, 3.4849193883471266] } }, Matrix { data: VecStorage { data: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -1.9513896933988457, -4.426942420881461, 0.0, 0.0, 0.0, -0.28264084049240257], nrows: Dynamic { value: 6 }, ncols: Dynamic { value: 2 } } })
|
||||||
|
cc a4effd988fe352146fca365875e108ecf4f7d41f6ad54683e923ca6ce712e5d0 # shrinks to (a, b) = (CscMatrix { cs: CsMatrix { sparsity_pattern: SparsityPattern { major_offsets: [0, 5, 11, 17, 22, 27, 31], minor_indices: [0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 0, 1, 3, 4, 5, 1, 2, 3, 4, 5, 0, 1, 3, 5], minor_dim: 6 }, values: [-2.24935510943371, -2.2288203680206227, 0.0, -1.029740125494273, 0.0, 0.0, 0.22632926934348507, -0.9123245943877407, 0.0, 3.8564332876991827, 0.0, 0.0, 0.0, -0.8235065737081717, 1.9337984046721566, 0.11003468246027737, -3.422112890579867, -3.7824068893569196, 0.0, -0.021700572247226546, -4.914783069982362, 0.6227245544506541, 0.0, 0.0, -4.411368879922364, -0.00013623178651567258, -2.613658177661417, -2.2783292441548637, 0.0, 1.351859435890189, -0.021345159183605134] } }, Matrix { data: VecStorage { data: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -4.519417607973404, 0.0, 0.0, 0.0, -0.21238483334481817], nrows: Dynamic { value: 6 }, ncols: Dynamic { value: 3 } } })
|
1235
nalgebra-sparse/tests/unit_tests/ops.rs
Normal file
1235
nalgebra-sparse/tests/unit_tests/ops.rs
Normal file
File diff suppressed because it is too large
Load Diff
154
nalgebra-sparse/tests/unit_tests/pattern.rs
Normal file
154
nalgebra-sparse/tests/unit_tests/pattern.rs
Normal file
@ -0,0 +1,154 @@
|
|||||||
|
use nalgebra_sparse::pattern::{SparsityPattern, SparsityPatternFormatError};
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn sparsity_pattern_valid_data() {
|
||||||
|
// Construct pattern from valid data and check that selected methods return results
|
||||||
|
// that agree with expectations.
|
||||||
|
|
||||||
|
{
|
||||||
|
// A pattern with zero explicitly stored entries
|
||||||
|
let pattern =
|
||||||
|
SparsityPattern::try_from_offsets_and_indices(3, 2, vec![0, 0, 0, 0], Vec::new())
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert_eq!(pattern.major_dim(), 3);
|
||||||
|
assert_eq!(pattern.minor_dim(), 2);
|
||||||
|
assert_eq!(pattern.nnz(), 0);
|
||||||
|
assert_eq!(pattern.major_offsets(), &[0, 0, 0, 0]);
|
||||||
|
assert_eq!(pattern.minor_indices(), &[]);
|
||||||
|
assert_eq!(pattern.lane(0), &[]);
|
||||||
|
assert_eq!(pattern.lane(1), &[]);
|
||||||
|
assert_eq!(pattern.lane(2), &[]);
|
||||||
|
assert!(pattern.entries().next().is_none());
|
||||||
|
|
||||||
|
assert_eq!(pattern, SparsityPattern::zeros(3, 2));
|
||||||
|
|
||||||
|
let (offsets, indices) = pattern.disassemble();
|
||||||
|
assert_eq!(offsets, vec![0, 0, 0, 0]);
|
||||||
|
assert_eq!(indices, vec![]);
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
// Arbitrary pattern
|
||||||
|
let offsets = vec![0, 2, 2, 5];
|
||||||
|
let indices = vec![0, 5, 1, 2, 3];
|
||||||
|
let pattern =
|
||||||
|
SparsityPattern::try_from_offsets_and_indices(3, 6, offsets.clone(), indices.clone())
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert_eq!(pattern.major_dim(), 3);
|
||||||
|
assert_eq!(pattern.minor_dim(), 6);
|
||||||
|
assert_eq!(pattern.major_offsets(), offsets.as_slice());
|
||||||
|
assert_eq!(pattern.minor_indices(), indices.as_slice());
|
||||||
|
assert_eq!(pattern.nnz(), 5);
|
||||||
|
assert_eq!(pattern.lane(0), &[0, 5]);
|
||||||
|
assert_eq!(pattern.lane(1), &[]);
|
||||||
|
assert_eq!(pattern.lane(2), &[1, 2, 3]);
|
||||||
|
assert_eq!(
|
||||||
|
pattern.entries().collect::<Vec<_>>(),
|
||||||
|
vec![(0, 0), (0, 5), (2, 1), (2, 2), (2, 3)]
|
||||||
|
);
|
||||||
|
|
||||||
|
let (offsets2, indices2) = pattern.disassemble();
|
||||||
|
assert_eq!(offsets2, offsets);
|
||||||
|
assert_eq!(indices2, indices);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn sparsity_pattern_try_from_invalid_data() {
|
||||||
|
{
|
||||||
|
// Empty offset array (invalid length)
|
||||||
|
let pattern = SparsityPattern::try_from_offsets_and_indices(0, 0, Vec::new(), Vec::new());
|
||||||
|
assert_eq!(
|
||||||
|
pattern,
|
||||||
|
Err(SparsityPatternFormatError::InvalidOffsetArrayLength)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
// Offset array invalid length for arbitrary data
|
||||||
|
let offsets = vec![0, 3, 5];
|
||||||
|
let indices = vec![0, 1, 2, 3, 5];
|
||||||
|
|
||||||
|
let pattern = SparsityPattern::try_from_offsets_and_indices(3, 6, offsets, indices);
|
||||||
|
assert!(matches!(
|
||||||
|
pattern,
|
||||||
|
Err(SparsityPatternFormatError::InvalidOffsetArrayLength)
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
// Invalid first entry in offsets array
|
||||||
|
let offsets = vec![1, 2, 2, 5];
|
||||||
|
let indices = vec![0, 5, 1, 2, 3];
|
||||||
|
let pattern = SparsityPattern::try_from_offsets_and_indices(3, 6, offsets, indices);
|
||||||
|
assert!(matches!(
|
||||||
|
pattern,
|
||||||
|
Err(SparsityPatternFormatError::InvalidOffsetFirstLast)
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
// Invalid last entry in offsets array
|
||||||
|
let offsets = vec![0, 2, 2, 4];
|
||||||
|
let indices = vec![0, 5, 1, 2, 3];
|
||||||
|
let pattern = SparsityPattern::try_from_offsets_and_indices(3, 6, offsets, indices);
|
||||||
|
assert!(matches!(
|
||||||
|
pattern,
|
||||||
|
Err(SparsityPatternFormatError::InvalidOffsetFirstLast)
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
// Invalid length of offsets array
|
||||||
|
let offsets = vec![0, 2, 2];
|
||||||
|
let indices = vec![0, 5, 1, 2, 3];
|
||||||
|
let pattern = SparsityPattern::try_from_offsets_and_indices(3, 6, offsets, indices);
|
||||||
|
assert!(matches!(
|
||||||
|
pattern,
|
||||||
|
Err(SparsityPatternFormatError::InvalidOffsetArrayLength)
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
// Nonmonotonic offsets
|
||||||
|
let offsets = vec![0, 3, 2, 5];
|
||||||
|
let indices = vec![0, 1, 2, 3, 4];
|
||||||
|
let pattern = SparsityPattern::try_from_offsets_and_indices(3, 6, offsets, indices);
|
||||||
|
assert_eq!(
|
||||||
|
pattern,
|
||||||
|
Err(SparsityPatternFormatError::NonmonotonicOffsets)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
// Nonmonotonic minor indices
|
||||||
|
let offsets = vec![0, 2, 2, 5];
|
||||||
|
let indices = vec![0, 2, 3, 1, 4];
|
||||||
|
let pattern = SparsityPattern::try_from_offsets_and_indices(3, 6, offsets, indices);
|
||||||
|
assert_eq!(
|
||||||
|
pattern,
|
||||||
|
Err(SparsityPatternFormatError::NonmonotonicMinorIndices)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
// Minor index out of bounds
|
||||||
|
let offsets = vec![0, 2, 2, 5];
|
||||||
|
let indices = vec![0, 6, 1, 2, 3];
|
||||||
|
let pattern = SparsityPattern::try_from_offsets_and_indices(3, 6, offsets, indices);
|
||||||
|
assert_eq!(
|
||||||
|
pattern,
|
||||||
|
Err(SparsityPatternFormatError::MinorIndexOutOfBounds)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
// Duplicate entry
|
||||||
|
let offsets = vec![0, 2, 2, 5];
|
||||||
|
let indices = vec![0, 5, 2, 2, 3];
|
||||||
|
let pattern = SparsityPattern::try_from_offsets_and_indices(3, 6, offsets, indices);
|
||||||
|
assert_eq!(pattern, Err(SparsityPatternFormatError::DuplicateEntry));
|
||||||
|
}
|
||||||
|
}
|
247
nalgebra-sparse/tests/unit_tests/proptest.rs
Normal file
247
nalgebra-sparse/tests/unit_tests/proptest.rs
Normal file
@ -0,0 +1,247 @@
|
|||||||
|
#[test]
|
||||||
|
#[ignore]
|
||||||
|
fn coo_no_duplicates_generates_admissible_matrices() {
|
||||||
|
//TODO
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "slow-tests")]
|
||||||
|
mod slow {
|
||||||
|
use nalgebra::DMatrix;
|
||||||
|
use nalgebra_sparse::proptest::{
|
||||||
|
coo_no_duplicates, coo_with_duplicates, csc, csr, sparsity_pattern,
|
||||||
|
};
|
||||||
|
|
||||||
|
use itertools::Itertools;
|
||||||
|
use proptest::strategy::ValueTree;
|
||||||
|
use proptest::test_runner::TestRunner;
|
||||||
|
|
||||||
|
use proptest::prelude::*;
|
||||||
|
|
||||||
|
use nalgebra_sparse::csr::CsrMatrix;
|
||||||
|
use std::collections::HashSet;
|
||||||
|
use std::iter::repeat;
|
||||||
|
use std::ops::RangeInclusive;
|
||||||
|
|
||||||
|
fn generate_all_possible_matrices(
|
||||||
|
value_range: RangeInclusive<i32>,
|
||||||
|
rows_range: RangeInclusive<usize>,
|
||||||
|
cols_range: RangeInclusive<usize>,
|
||||||
|
) -> HashSet<DMatrix<i32>> {
|
||||||
|
// Enumerate all possible combinations
|
||||||
|
let mut all_combinations = HashSet::new();
|
||||||
|
for nrows in rows_range {
|
||||||
|
for ncols in cols_range.clone() {
|
||||||
|
// For the given number of rows and columns
|
||||||
|
let n_values = nrows * ncols;
|
||||||
|
|
||||||
|
if n_values == 0 {
|
||||||
|
// If we have zero rows or columns, the set of matrices with the given
|
||||||
|
// rows and columns is a single element: an empty matrix
|
||||||
|
all_combinations.insert(DMatrix::from_row_slice(nrows, ncols, &[]));
|
||||||
|
} else {
|
||||||
|
// Otherwise, we need to sample all possible matrices.
|
||||||
|
// To do this, we generate the values as the (multi) Cartesian product
|
||||||
|
// of the value sets. For example, for a 2x2 matrices, we consider
|
||||||
|
// all possible 4-element arrays that the matrices can take by
|
||||||
|
// considering all elements in the cartesian product
|
||||||
|
// V x V x V x V
|
||||||
|
// where V is the set of eligible values, e.g. V := -1 ..= 1
|
||||||
|
let values_iter = repeat(value_range.clone())
|
||||||
|
.take(n_values)
|
||||||
|
.multi_cartesian_product();
|
||||||
|
for matrix_values in values_iter {
|
||||||
|
all_combinations.insert(DMatrix::from_row_slice(
|
||||||
|
nrows,
|
||||||
|
ncols,
|
||||||
|
&matrix_values,
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
all_combinations
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "slow-tests")]
|
||||||
|
#[test]
|
||||||
|
fn coo_no_duplicates_samples_all_admissible_outputs() {
|
||||||
|
// Note: This test basically mirrors a similar test for `matrix` in the `nalgebra` repo.
|
||||||
|
|
||||||
|
// Test that the proptest generation covers all possible outputs for a small space of inputs
|
||||||
|
// given enough samples.
|
||||||
|
|
||||||
|
// We use a deterministic test runner to make the test "stable".
|
||||||
|
let mut runner = TestRunner::deterministic();
|
||||||
|
|
||||||
|
// This number needs to be high enough so that we with high probability sample
|
||||||
|
// all possible cases
|
||||||
|
let num_generated_matrices = 500000;
|
||||||
|
|
||||||
|
let values = -1..=1;
|
||||||
|
let rows = 0..=2;
|
||||||
|
let cols = 0..=3;
|
||||||
|
let max_nnz = rows.end() * cols.end();
|
||||||
|
let strategy = coo_no_duplicates(values.clone(), rows.clone(), cols.clone(), max_nnz);
|
||||||
|
|
||||||
|
// Enumerate all possible combinations
|
||||||
|
let all_combinations = generate_all_possible_matrices(values, rows, cols);
|
||||||
|
|
||||||
|
let visited_combinations =
|
||||||
|
sample_matrix_output_space(strategy, &mut runner, num_generated_matrices);
|
||||||
|
|
||||||
|
assert_eq!(visited_combinations.len(), all_combinations.len());
|
||||||
|
assert_eq!(
|
||||||
|
visited_combinations, all_combinations,
|
||||||
|
"Did not sample all possible values."
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "slow-tests")]
|
||||||
|
#[test]
|
||||||
|
fn coo_with_duplicates_samples_all_admissible_outputs() {
|
||||||
|
// This is almost the same as the test for coo_no_duplicates, except that we need
|
||||||
|
// a different "success" criterion, since coo_with_duplicates is able to generate
|
||||||
|
// matrices with values outside of the value constraints. See below for details.
|
||||||
|
|
||||||
|
// We use a deterministic test runner to make the test "stable".
|
||||||
|
let mut runner = TestRunner::deterministic();
|
||||||
|
|
||||||
|
// This number needs to be high enough so that we with high probability sample
|
||||||
|
// all possible cases
|
||||||
|
let num_generated_matrices = 500000;
|
||||||
|
|
||||||
|
let values = -1..=1;
|
||||||
|
let rows = 0..=2;
|
||||||
|
let cols = 0..=3;
|
||||||
|
let max_nnz = rows.end() * cols.end();
|
||||||
|
let strategy = coo_with_duplicates(values.clone(), rows.clone(), cols.clone(), max_nnz, 2);
|
||||||
|
|
||||||
|
// Enumerate all possible combinations that fit the constraints
|
||||||
|
// (note: this is only a subset of the matrices that can be generated by
|
||||||
|
// `coo_with_duplicates`)
|
||||||
|
let all_combinations = generate_all_possible_matrices(values, rows, cols);
|
||||||
|
|
||||||
|
let visited_combinations =
|
||||||
|
sample_matrix_output_space(strategy, &mut runner, num_generated_matrices);
|
||||||
|
|
||||||
|
// Here we cannot verify that the set of visited combinations is *equal* to
|
||||||
|
// all possible outcomes with the given constraints, however the
|
||||||
|
// strategy should be able to generate all matrices that fit the constraints.
|
||||||
|
// In other words, we need to determine that set of all admissible matrices
|
||||||
|
// is contained in the set of visited matrices
|
||||||
|
assert!(all_combinations.is_subset(&visited_combinations));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "slow-tests")]
|
||||||
|
#[test]
|
||||||
|
fn csr_samples_all_admissible_outputs() {
|
||||||
|
// We use a deterministic test runner to make the test "stable".
|
||||||
|
let mut runner = TestRunner::deterministic();
|
||||||
|
|
||||||
|
// This number needs to be high enough so that we with high probability sample
|
||||||
|
// all possible cases
|
||||||
|
let num_generated_matrices = 500000;
|
||||||
|
|
||||||
|
let values = -1..=1;
|
||||||
|
let rows = 0..=2;
|
||||||
|
let cols = 0..=3;
|
||||||
|
let max_nnz = rows.end() * cols.end();
|
||||||
|
let strategy = csr(values.clone(), rows.clone(), cols.clone(), max_nnz);
|
||||||
|
|
||||||
|
let all_combinations = generate_all_possible_matrices(values, rows, cols);
|
||||||
|
|
||||||
|
let visited_combinations =
|
||||||
|
sample_matrix_output_space(strategy, &mut runner, num_generated_matrices);
|
||||||
|
|
||||||
|
assert_eq!(visited_combinations.len(), all_combinations.len());
|
||||||
|
assert_eq!(
|
||||||
|
visited_combinations, all_combinations,
|
||||||
|
"Did not sample all possible values."
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "slow-tests")]
|
||||||
|
#[test]
|
||||||
|
fn csc_samples_all_admissible_outputs() {
|
||||||
|
// We use a deterministic test runner to make the test "stable".
|
||||||
|
let mut runner = TestRunner::deterministic();
|
||||||
|
|
||||||
|
// This number needs to be high enough so that we with high probability sample
|
||||||
|
// all possible cases
|
||||||
|
let num_generated_matrices = 500000;
|
||||||
|
|
||||||
|
let values = -1..=1;
|
||||||
|
let rows = 0..=2;
|
||||||
|
let cols = 0..=3;
|
||||||
|
let max_nnz = rows.end() * cols.end();
|
||||||
|
let strategy = csc(values.clone(), rows.clone(), cols.clone(), max_nnz);
|
||||||
|
|
||||||
|
let all_combinations = generate_all_possible_matrices(values, rows, cols);
|
||||||
|
|
||||||
|
let visited_combinations =
|
||||||
|
sample_matrix_output_space(strategy, &mut runner, num_generated_matrices);
|
||||||
|
|
||||||
|
assert_eq!(visited_combinations.len(), all_combinations.len());
|
||||||
|
assert_eq!(
|
||||||
|
visited_combinations, all_combinations,
|
||||||
|
"Did not sample all possible values."
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "slow-tests")]
|
||||||
|
#[test]
|
||||||
|
fn sparsity_pattern_samples_all_admissible_outputs() {
|
||||||
|
let mut runner = TestRunner::deterministic();
|
||||||
|
|
||||||
|
let num_generated_patterns = 50000;
|
||||||
|
|
||||||
|
let major_dims = 0..=2;
|
||||||
|
let minor_dims = 0..=3;
|
||||||
|
let max_nnz = major_dims.end() * minor_dims.end();
|
||||||
|
let strategy = sparsity_pattern(major_dims.clone(), minor_dims.clone(), max_nnz);
|
||||||
|
|
||||||
|
let visited_patterns: HashSet<_> = sample_strategy(strategy, &mut runner)
|
||||||
|
.take(num_generated_patterns)
|
||||||
|
.map(|pattern| {
|
||||||
|
// We represent patterns as dense matrices with 1 if an entry is occupied,
|
||||||
|
// 0 otherwise
|
||||||
|
let values = vec![1; pattern.nnz()];
|
||||||
|
CsrMatrix::try_from_pattern_and_values(pattern, values).unwrap()
|
||||||
|
})
|
||||||
|
.map(|csr| DMatrix::from(&csr))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let all_possible_patterns = generate_all_possible_matrices(0..=1, major_dims, minor_dims);
|
||||||
|
|
||||||
|
assert_eq!(visited_patterns.len(), all_possible_patterns.len());
|
||||||
|
assert_eq!(visited_patterns, all_possible_patterns);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn sample_matrix_output_space<S>(
|
||||||
|
strategy: S,
|
||||||
|
runner: &mut TestRunner,
|
||||||
|
num_samples: usize,
|
||||||
|
) -> HashSet<DMatrix<i32>>
|
||||||
|
where
|
||||||
|
S: Strategy,
|
||||||
|
DMatrix<i32>: for<'b> From<&'b S::Value>,
|
||||||
|
{
|
||||||
|
sample_strategy(strategy, runner)
|
||||||
|
.take(num_samples)
|
||||||
|
.map(|matrix| DMatrix::from(&matrix))
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn sample_strategy<'a, S: 'a + Strategy>(
|
||||||
|
strategy: S,
|
||||||
|
runner: &'a mut TestRunner,
|
||||||
|
) -> impl 'a + Iterator<Item = S::Value> {
|
||||||
|
repeat(()).map(move |_| {
|
||||||
|
let tree = strategy
|
||||||
|
.new_tree(runner)
|
||||||
|
.expect("Tree generation should not fail");
|
||||||
|
let value = tree.current();
|
||||||
|
value
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
@ -128,6 +128,8 @@ pub mod geometry;
|
|||||||
#[cfg(feature = "io")]
|
#[cfg(feature = "io")]
|
||||||
pub mod io;
|
pub mod io;
|
||||||
pub mod linalg;
|
pub mod linalg;
|
||||||
|
#[cfg(feature = "proptest-support")]
|
||||||
|
pub mod proptest;
|
||||||
#[cfg(feature = "sparse")]
|
#[cfg(feature = "sparse")]
|
||||||
pub mod sparse;
|
pub mod sparse;
|
||||||
|
|
||||||
|
476
src/proptest/mod.rs
Normal file
476
src/proptest/mod.rs
Normal file
@ -0,0 +1,476 @@
|
|||||||
|
//! `proptest`-related features for `nalgebra` data structures.
|
||||||
|
//!
|
||||||
|
//! **This module is only available when the `proptest-support` feature is enabled in `nalgebra`**.
|
||||||
|
//!
|
||||||
|
//! `proptest` is a library for *property-based testing*. While similar to QuickCheck,
|
||||||
|
//! which may be more familiar to some users, it has a more sophisticated design that
|
||||||
|
//! provides users with automatic invariant-preserving shrinking. This means that when using
|
||||||
|
//! `proptest`, you rarely need to write your own shrinkers - which is usually very difficult -
|
||||||
|
//! and can instead get this "for free". Moreover, `proptest` does not rely on a canonical
|
||||||
|
//! `Arbitrary` trait implementation like QuickCheck, though it does also provide this. For
|
||||||
|
//! more information, check out the [proptest docs](https://docs.rs/proptest/0.10.1/proptest/)
|
||||||
|
//! and the [proptest book](https://altsysrq.github.io/proptest-book/intro.html).
|
||||||
|
//!
|
||||||
|
//! This module provides users of `nalgebra` with tools to work with `nalgebra` types in
|
||||||
|
//! `proptest` tests. At present, this integration is at an early stage, and only
|
||||||
|
//! provides tools for generating matrices and vectors, and not any of the geometry types.
|
||||||
|
//! There are essentially two ways of using this functionality:
|
||||||
|
//!
|
||||||
|
//! - Using the [matrix](fn.matrix.html) function to generate matrices with constraints
|
||||||
|
//! on dimensions and elements.
|
||||||
|
//! - Relying on the `Arbitrary` implementation of `MatrixMN`.
|
||||||
|
//!
|
||||||
|
//! The first variant is almost always preferred in practice. Read on to discover why.
|
||||||
|
//!
|
||||||
|
//! ### Using free function strategies
|
||||||
|
//!
|
||||||
|
//! In `proptest`, it is usually preferable to have free functions that generate *strategies*.
|
||||||
|
//! Currently, the [matrix](fn.matrix.html) function fills this role. The analogous function for
|
||||||
|
//! column vectors is [vector](fn.vector.html). Let's take a quick look at how it may be used:
|
||||||
|
//! ```rust
|
||||||
|
//! use nalgebra::proptest::matrix;
|
||||||
|
//! use proptest::prelude::*;
|
||||||
|
//!
|
||||||
|
//! proptest! {
|
||||||
|
//! # /*
|
||||||
|
//! #[test]
|
||||||
|
//! # */
|
||||||
|
//! fn my_test(a in matrix(-5 ..= 5, 2 ..= 4, 1..=4)) {
|
||||||
|
//! // Generates matrices with elements in the range -5 ..= 5, rows in 2..=4 and
|
||||||
|
//! // columns in 1..=4.
|
||||||
|
//! }
|
||||||
|
//! }
|
||||||
|
//!
|
||||||
|
//! # fn main() { my_test(); }
|
||||||
|
//! ```
|
||||||
|
//!
|
||||||
|
//! In the above example, we generate matrices with constraints on the elements, as well as the
|
||||||
|
//! on the allowed dimensions. When a failing example is found, the resulting shrinking process
|
||||||
|
//! will preserve these invariants. We can use this to compose more advanced strategies.
|
||||||
|
//! For example, let's consider a toy example where we need to generate pairs of matrices
|
||||||
|
//! with exactly 3 rows fixed at compile-time and the same number of columns, but we want the
|
||||||
|
//! number of columns to vary. One way to do this is to use `proptest` combinators in combination
|
||||||
|
//! with [matrix](fn.matrix.html) as follows:
|
||||||
|
//!
|
||||||
|
//! ```rust
|
||||||
|
//! use nalgebra::{Dynamic, MatrixMN, U3};
|
||||||
|
//! use nalgebra::proptest::matrix;
|
||||||
|
//! use proptest::prelude::*;
|
||||||
|
//!
|
||||||
|
//! type MyMatrix = MatrixMN<i32, U3, Dynamic>;
|
||||||
|
//!
|
||||||
|
//! /// Returns a strategy for pairs of matrices with `U3` rows and the same number of
|
||||||
|
//! /// columns.
|
||||||
|
//! fn matrix_pairs() -> impl Strategy<Value=(MyMatrix, MyMatrix)> {
|
||||||
|
//! matrix(-5 ..= 5, U3, 0 ..= 10)
|
||||||
|
//! // We first generate the initial matrix `a`, and then depending on the concrete
|
||||||
|
//! // instances of `a`, we pick a second matrix with the same number of columns
|
||||||
|
//! .prop_flat_map(|a| {
|
||||||
|
//! let b = matrix(-5 .. 5, U3, a.ncols());
|
||||||
|
//! // This returns a new tuple strategy where we keep `a` fixed while
|
||||||
|
//! // the second item is a strategy that generates instances with the same
|
||||||
|
//! // dimensions as `a`
|
||||||
|
//! (Just(a), b)
|
||||||
|
//! })
|
||||||
|
//! }
|
||||||
|
//!
|
||||||
|
//! proptest! {
|
||||||
|
//! # /*
|
||||||
|
//! #[test]
|
||||||
|
//! # */
|
||||||
|
//! fn my_test((a, b) in matrix_pairs()) {
|
||||||
|
//! // Let's double-check that the two matrices do indeed have the same number of
|
||||||
|
//! // columns
|
||||||
|
//! prop_assert_eq!(a.ncols(), b.ncols());
|
||||||
|
//! }
|
||||||
|
//! }
|
||||||
|
//!
|
||||||
|
//! # fn main() { my_test(); }
|
||||||
|
//! ```
|
||||||
|
//!
|
||||||
|
//! ### The `Arbitrary` implementation
|
||||||
|
//!
|
||||||
|
//! If you don't care about the dimensions of matrices, you can write tests like these:
|
||||||
|
//!
|
||||||
|
//! ```rust
|
||||||
|
//! use nalgebra::{DMatrix, DVector, Dynamic, Matrix3, MatrixMN, Vector3, U3};
|
||||||
|
//! use proptest::prelude::*;
|
||||||
|
//!
|
||||||
|
//! proptest! {
|
||||||
|
//! # /*
|
||||||
|
//! #[test]
|
||||||
|
//! # */
|
||||||
|
//! fn test_dynamic(matrix: DMatrix<i32>) {
|
||||||
|
//! // This will generate arbitrary instances of `DMatrix` and also attempt
|
||||||
|
//! // to shrink/simplify them when test failures are encountered.
|
||||||
|
//! }
|
||||||
|
//!
|
||||||
|
//! # /*
|
||||||
|
//! #[test]
|
||||||
|
//! # */
|
||||||
|
//! fn test_static_and_mixed(matrix: Matrix3<i32>, matrix2: MatrixMN<i32, U3, Dynamic>) {
|
||||||
|
//! // Test some property involving these matrices
|
||||||
|
//! }
|
||||||
|
//!
|
||||||
|
//! # /*
|
||||||
|
//! #[test]
|
||||||
|
//! # */
|
||||||
|
//! fn test_vectors(fixed_size_vector: Vector3<i32>, dyn_vector: DVector<i32>) {
|
||||||
|
//! // Test some property involving these vectors
|
||||||
|
//! }
|
||||||
|
//! }
|
||||||
|
//!
|
||||||
|
//! # fn main() { test_dynamic(); test_static_and_mixed(); test_vectors(); }
|
||||||
|
//! ```
|
||||||
|
//!
|
||||||
|
//! While this may be convenient, the default strategies for built-in types in `proptest` can
|
||||||
|
//! generate *any* number, including integers large enough to easily lead to overflow when used in
|
||||||
|
//! matrix operations, or even infinity or NaN values for floating-point types. Therefore
|
||||||
|
//! `Arbitrary` is rarely the method of choice for writing property-based tests.
|
||||||
|
//!
|
||||||
|
//! ### Notes on shrinking
|
||||||
|
//!
|
||||||
|
//! Due to some limitations of the current implementation, shrinking takes place by first
|
||||||
|
//! shrinking the matrix elements before trying to shrink the dimensions of the matrix.
|
||||||
|
//! This unfortunately often leads to the fact that a large number of shrinking iterations
|
||||||
|
//! are necessary to find a (nearly) minimal failing test case. As a workaround for this,
|
||||||
|
//! you can increase the maximum number of shrinking iterations when debugging. To do this,
|
||||||
|
//! simply set the `PROPTEST_MAX_SHRINK_ITERS` variable to a high number. For example:
|
||||||
|
//!
|
||||||
|
//! ```text
|
||||||
|
//! PROPTEST_MAX_SHRINK_ITERS=100000 cargo test my_failing_test
|
||||||
|
//! ```
|
||||||
|
use crate::allocator::Allocator;
|
||||||
|
use crate::{DefaultAllocator, Dim, DimName, Dynamic, MatrixMN, Scalar, U1};
|
||||||
|
use proptest::arbitrary::Arbitrary;
|
||||||
|
use proptest::collection::vec;
|
||||||
|
use proptest::strategy::{BoxedStrategy, Just, NewTree, Strategy, ValueTree};
|
||||||
|
use proptest::test_runner::TestRunner;
|
||||||
|
|
||||||
|
use std::ops::RangeInclusive;
|
||||||
|
|
||||||
|
/// Parameters for arbitrary matrix generation.
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
#[non_exhaustive]
|
||||||
|
pub struct MatrixParameters<NParameters, R, C> {
|
||||||
|
/// The range of rows that may be generated.
|
||||||
|
pub rows: DimRange<R>,
|
||||||
|
/// The range of columns that may be generated.
|
||||||
|
pub cols: DimRange<C>,
|
||||||
|
/// Parameters for the `Arbitrary` implementation of the scalar values.
|
||||||
|
pub value_parameters: NParameters,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A range of allowed dimensions for use in generation of matrices.
|
||||||
|
///
|
||||||
|
/// The `DimRange` type is used to encode the range of dimensions that can be used for generation
|
||||||
|
/// of matrices with `proptest`. In most cases, you do not need to concern yourself with
|
||||||
|
/// `DimRange` directly, as it supports conversion from other types such as `U3` or inclusive
|
||||||
|
/// ranges such as `5 ..= 6`. The latter example corresponds to dimensions from (inclusive)
|
||||||
|
/// `Dynamic::new(5)` to `Dynamic::new(6)` (inclusive).
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub struct DimRange<D = Dynamic>(RangeInclusive<D>);
|
||||||
|
|
||||||
|
impl<D: Dim> DimRange<D> {
|
||||||
|
/// The lower bound for dimensions generated.
|
||||||
|
pub fn lower_bound(&self) -> D {
|
||||||
|
*self.0.start()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The upper bound for dimensions generated.
|
||||||
|
pub fn upper_bound(&self) -> D {
|
||||||
|
*self.0.end()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<D: Dim> From<D> for DimRange<D> {
|
||||||
|
fn from(dim: D) -> Self {
|
||||||
|
DimRange(dim..=dim)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<D: Dim> From<RangeInclusive<D>> for DimRange<D> {
|
||||||
|
fn from(range: RangeInclusive<D>) -> Self {
|
||||||
|
DimRange(range)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<RangeInclusive<usize>> for DimRange<Dynamic> {
|
||||||
|
fn from(range: RangeInclusive<usize>) -> Self {
|
||||||
|
DimRange::from(Dynamic::new(*range.start())..=Dynamic::new(*range.end()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<D: Dim> DimRange<D> {
|
||||||
|
/// Converts the `DimRange` into an instance of `RangeInclusive`.
|
||||||
|
pub fn to_range_inclusive(&self) -> RangeInclusive<usize> {
|
||||||
|
self.lower_bound().value()..=self.upper_bound().value()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<usize> for DimRange<Dynamic> {
|
||||||
|
fn from(dim: usize) -> Self {
|
||||||
|
DimRange::from(Dynamic::new(dim))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The default range used for Dynamic dimensions when generating arbitrary matrices.
|
||||||
|
fn dynamic_dim_range() -> DimRange<Dynamic> {
|
||||||
|
DimRange::from(0..=6)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a strategy to generate matrices containing values drawn from the given strategy,
|
||||||
|
/// with rows and columns in the provided ranges.
|
||||||
|
///
|
||||||
|
/// ## Examples
|
||||||
|
/// ```
|
||||||
|
/// use nalgebra::proptest::matrix;
|
||||||
|
/// use nalgebra::{MatrixMN, U3, Dynamic};
|
||||||
|
/// use proptest::prelude::*;
|
||||||
|
///
|
||||||
|
/// proptest! {
|
||||||
|
/// # /*
|
||||||
|
/// #[test]
|
||||||
|
/// # */
|
||||||
|
/// fn my_test(a in matrix(0 .. 5i32, U3, 0 ..= 5)) {
|
||||||
|
/// // Let's make sure we've got the correct type first
|
||||||
|
/// let a: MatrixMN<_, U3, Dynamic> = a;
|
||||||
|
/// prop_assert!(a.nrows() == 3);
|
||||||
|
/// prop_assert!(a.ncols() <= 5);
|
||||||
|
/// prop_assert!(a.iter().all(|x_ij| *x_ij >= 0 && *x_ij < 5));
|
||||||
|
/// }
|
||||||
|
/// }
|
||||||
|
///
|
||||||
|
/// # fn main() { my_test(); }
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
|
/// ## Limitations
|
||||||
|
/// The current implementation has some limitations that lead to suboptimal shrinking behavior.
|
||||||
|
/// See the [module-level documentation](index.html) for more.
|
||||||
|
pub fn matrix<R, C, ScalarStrategy>(
|
||||||
|
value_strategy: ScalarStrategy,
|
||||||
|
rows: impl Into<DimRange<R>>,
|
||||||
|
cols: impl Into<DimRange<C>>,
|
||||||
|
) -> MatrixStrategy<ScalarStrategy, R, C>
|
||||||
|
where
|
||||||
|
ScalarStrategy: Strategy + Clone + 'static,
|
||||||
|
ScalarStrategy::Value: Scalar,
|
||||||
|
R: Dim,
|
||||||
|
C: Dim,
|
||||||
|
DefaultAllocator: Allocator<ScalarStrategy::Value, R, C>,
|
||||||
|
{
|
||||||
|
matrix_(value_strategy, rows.into(), cols.into())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Same as `matrix`, but without the additional anonymous generic types
|
||||||
|
fn matrix_<R, C, ScalarStrategy>(
|
||||||
|
value_strategy: ScalarStrategy,
|
||||||
|
rows: DimRange<R>,
|
||||||
|
cols: DimRange<C>,
|
||||||
|
) -> MatrixStrategy<ScalarStrategy, R, C>
|
||||||
|
where
|
||||||
|
ScalarStrategy: Strategy + Clone + 'static,
|
||||||
|
ScalarStrategy::Value: Scalar,
|
||||||
|
R: Dim,
|
||||||
|
C: Dim,
|
||||||
|
DefaultAllocator: Allocator<ScalarStrategy::Value, R, C>,
|
||||||
|
{
|
||||||
|
let nrows = rows.lower_bound().value()..=rows.upper_bound().value();
|
||||||
|
let ncols = cols.lower_bound().value()..=cols.upper_bound().value();
|
||||||
|
|
||||||
|
// Even though we can use this function to generate fixed-size matrices,
|
||||||
|
// we currently generate all matrices with heap allocated Vec data.
|
||||||
|
// TODO: Avoid heap allocation for fixed-size matrices.
|
||||||
|
// Doing this *properly* would probably require us to implement a custom
|
||||||
|
// strategy and valuetree with custom shrinking logic, which is not trivial
|
||||||
|
|
||||||
|
// Perhaps more problematic, however, is the poor shrinking behavior the current setup leads to.
|
||||||
|
// Shrinking in proptest basically happens in "reverse" of the combinators, so
|
||||||
|
// by first generating the dimensions and then the elements, we get shrinking that first
|
||||||
|
// tries to completely shrink the individual elements before trying to reduce the dimension.
|
||||||
|
// This is clearly the opposite of what we want. I can't find any good way around this
|
||||||
|
// short of writing our own custom value tree, which we should probably do at some point.
|
||||||
|
// TODO: Custom implementation of value tree for better shrinking behavior.
|
||||||
|
|
||||||
|
let strategy = nrows
|
||||||
|
.prop_flat_map(move |nrows| (Just(nrows), ncols.clone()))
|
||||||
|
.prop_flat_map(move |(nrows, ncols)| {
|
||||||
|
(
|
||||||
|
Just(nrows),
|
||||||
|
Just(ncols),
|
||||||
|
vec(value_strategy.clone(), nrows * ncols),
|
||||||
|
)
|
||||||
|
})
|
||||||
|
.prop_map(|(nrows, ncols, values)| {
|
||||||
|
// Note: R/C::from_usize will panic if nrows/ncols does not fit in the dimension type.
|
||||||
|
// However, this should never fail, because we should only be generating
|
||||||
|
// this stuff in the first place
|
||||||
|
MatrixMN::from_iterator_generic(R::from_usize(nrows), C::from_usize(ncols), values)
|
||||||
|
})
|
||||||
|
.boxed();
|
||||||
|
|
||||||
|
MatrixStrategy { strategy }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a strategy to generate column vectors containing values drawn from the given strategy,
|
||||||
|
/// with length in the provided range.
|
||||||
|
///
|
||||||
|
/// This is a convenience function for calling
|
||||||
|
/// [matrix(value_strategy, length, U1)](fn.matrix.html) and should
|
||||||
|
/// be used when you only want to generate column vectors, as it's simpler and makes the intent
|
||||||
|
/// clear.
|
||||||
|
pub fn vector<D, ScalarStrategy>(
|
||||||
|
value_strategy: ScalarStrategy,
|
||||||
|
length: impl Into<DimRange<D>>,
|
||||||
|
) -> MatrixStrategy<ScalarStrategy, D, U1>
|
||||||
|
where
|
||||||
|
ScalarStrategy: Strategy + Clone + 'static,
|
||||||
|
ScalarStrategy::Value: Scalar,
|
||||||
|
D: Dim,
|
||||||
|
DefaultAllocator: Allocator<ScalarStrategy::Value, D>,
|
||||||
|
{
|
||||||
|
matrix_(value_strategy, length.into(), U1.into())
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<NParameters, R, C> Default for MatrixParameters<NParameters, R, C>
|
||||||
|
where
|
||||||
|
NParameters: Default,
|
||||||
|
R: DimName,
|
||||||
|
C: DimName,
|
||||||
|
{
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
rows: DimRange::from(R::name()),
|
||||||
|
cols: DimRange::from(C::name()),
|
||||||
|
value_parameters: NParameters::default(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<NParameters, R> Default for MatrixParameters<NParameters, R, Dynamic>
|
||||||
|
where
|
||||||
|
NParameters: Default,
|
||||||
|
R: DimName,
|
||||||
|
{
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
rows: DimRange::from(R::name()),
|
||||||
|
cols: dynamic_dim_range(),
|
||||||
|
value_parameters: NParameters::default(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<NParameters, C> Default for MatrixParameters<NParameters, Dynamic, C>
|
||||||
|
where
|
||||||
|
NParameters: Default,
|
||||||
|
C: DimName,
|
||||||
|
{
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
rows: dynamic_dim_range(),
|
||||||
|
cols: DimRange::from(C::name()),
|
||||||
|
value_parameters: NParameters::default(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<NParameters> Default for MatrixParameters<NParameters, Dynamic, Dynamic>
|
||||||
|
where
|
||||||
|
NParameters: Default,
|
||||||
|
{
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
rows: dynamic_dim_range(),
|
||||||
|
cols: dynamic_dim_range(),
|
||||||
|
value_parameters: NParameters::default(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<N, R, C> Arbitrary for MatrixMN<N, R, C>
|
||||||
|
where
|
||||||
|
N: Scalar + Arbitrary,
|
||||||
|
<N as Arbitrary>::Strategy: Clone,
|
||||||
|
R: Dim,
|
||||||
|
C: Dim,
|
||||||
|
MatrixParameters<N::Parameters, R, C>: Default,
|
||||||
|
DefaultAllocator: Allocator<N, R, C>,
|
||||||
|
{
|
||||||
|
type Parameters = MatrixParameters<N::Parameters, R, C>;
|
||||||
|
|
||||||
|
fn arbitrary_with(args: Self::Parameters) -> Self::Strategy {
|
||||||
|
let value_strategy = N::arbitrary_with(args.value_parameters);
|
||||||
|
matrix(value_strategy, args.rows, args.cols)
|
||||||
|
}
|
||||||
|
|
||||||
|
type Strategy = MatrixStrategy<N::Strategy, R, C>;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A strategy for generating matrices.
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct MatrixStrategy<NStrategy, R: Dim, C: Dim>
|
||||||
|
where
|
||||||
|
NStrategy: Strategy,
|
||||||
|
NStrategy::Value: Scalar,
|
||||||
|
DefaultAllocator: Allocator<NStrategy::Value, R, C>,
|
||||||
|
{
|
||||||
|
// For now we only internally hold a boxed strategy. The reason for introducing this
|
||||||
|
// separate wrapper struct is so that we can replace the strategy logic with custom logic
|
||||||
|
// later down the road without introducing significant breaking changes
|
||||||
|
strategy: BoxedStrategy<MatrixMN<NStrategy::Value, R, C>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<NStrategy, R, C> Strategy for MatrixStrategy<NStrategy, R, C>
|
||||||
|
where
|
||||||
|
NStrategy: Strategy,
|
||||||
|
NStrategy::Value: Scalar,
|
||||||
|
R: Dim,
|
||||||
|
C: Dim,
|
||||||
|
DefaultAllocator: Allocator<NStrategy::Value, R, C>,
|
||||||
|
{
|
||||||
|
type Tree = MatrixValueTree<NStrategy::Value, R, C>;
|
||||||
|
type Value = MatrixMN<NStrategy::Value, R, C>;
|
||||||
|
|
||||||
|
fn new_tree(&self, runner: &mut TestRunner) -> NewTree<Self> {
|
||||||
|
let underlying_tree = self.strategy.new_tree(runner)?;
|
||||||
|
Ok(MatrixValueTree {
|
||||||
|
value_tree: underlying_tree,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A value tree for matrices.
|
||||||
|
pub struct MatrixValueTree<N, R, C>
|
||||||
|
where
|
||||||
|
N: Scalar,
|
||||||
|
R: Dim,
|
||||||
|
C: Dim,
|
||||||
|
DefaultAllocator: Allocator<N, R, C>,
|
||||||
|
{
|
||||||
|
// For now we only wrap a boxed value tree. The reason for wrapping is that this allows us
|
||||||
|
// to swap out the value tree logic down the road without significant breaking changes.
|
||||||
|
value_tree: Box<dyn ValueTree<Value = MatrixMN<N, R, C>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<N, R, C> ValueTree for MatrixValueTree<N, R, C>
|
||||||
|
where
|
||||||
|
N: Scalar,
|
||||||
|
R: Dim,
|
||||||
|
C: Dim,
|
||||||
|
DefaultAllocator: Allocator<N, R, C>,
|
||||||
|
{
|
||||||
|
type Value = MatrixMN<N, R, C>;
|
||||||
|
|
||||||
|
fn current(&self) -> Self::Value {
|
||||||
|
self.value_tree.current()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn simplify(&mut self) -> bool {
|
||||||
|
self.value_tree.simplify()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn complicate(&mut self) -> bool {
|
||||||
|
self.value_tree.complicate()
|
||||||
|
}
|
||||||
|
}
|
@ -19,5 +19,9 @@ extern crate quickcheck;
|
|||||||
mod core;
|
mod core;
|
||||||
mod geometry;
|
mod geometry;
|
||||||
mod linalg;
|
mod linalg;
|
||||||
|
|
||||||
|
#[cfg(feature = "proptest-support")]
|
||||||
|
mod proptest;
|
||||||
|
|
||||||
//#[cfg(feature = "sparse")]
|
//#[cfg(feature = "sparse")]
|
||||||
//mod sparse;
|
//mod sparse;
|
||||||
|
219
tests/proptest/mod.rs
Normal file
219
tests/proptest/mod.rs
Normal file
@ -0,0 +1,219 @@
|
|||||||
|
//! Tests for proptest-related functionality.
|
||||||
|
use nalgebra::base::dimension::*;
|
||||||
|
use nalgebra::proptest::{matrix, DimRange, MatrixStrategy};
|
||||||
|
use nalgebra::{DMatrix, DVector, Dim, Matrix3, MatrixMN, Vector3};
|
||||||
|
use proptest::prelude::*;
|
||||||
|
use proptest::strategy::ValueTree;
|
||||||
|
use proptest::test_runner::TestRunner;
|
||||||
|
|
||||||
|
/// Generate a proptest that tests that all matrices generated with the
|
||||||
|
/// provided rows and columns conform to the constraints defined by the
|
||||||
|
/// input.
|
||||||
|
macro_rules! generate_matrix_sanity_test {
|
||||||
|
($test_name:ident, $rows:expr, $cols:expr) => {
|
||||||
|
proptest! {
|
||||||
|
#[test]
|
||||||
|
fn $test_name(a in matrix(-5 ..= 5i32, $rows, $cols)) {
|
||||||
|
// let a: MatrixMN<_, $rows, $cols> = a;
|
||||||
|
let rows_range = DimRange::from($rows);
|
||||||
|
let cols_range = DimRange::from($cols);
|
||||||
|
prop_assert!(a.nrows() >= rows_range.lower_bound().value()
|
||||||
|
&& a.nrows() <= rows_range.upper_bound().value());
|
||||||
|
prop_assert!(a.ncols() >= cols_range.lower_bound().value()
|
||||||
|
&& a.ncols() <= cols_range.upper_bound().value());
|
||||||
|
prop_assert!(a.iter().all(|x_ij| *x_ij >= -5 && *x_ij <= 5));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test all fixed-size matrices with row/col dimensions up to 3
|
||||||
|
generate_matrix_sanity_test!(test_matrix_u0_u0, U0, U0);
|
||||||
|
generate_matrix_sanity_test!(test_matrix_u1_u0, U1, U0);
|
||||||
|
generate_matrix_sanity_test!(test_matrix_u0_u1, U0, U1);
|
||||||
|
generate_matrix_sanity_test!(test_matrix_u1_u1, U1, U1);
|
||||||
|
generate_matrix_sanity_test!(test_matrix_u2_u1, U2, U1);
|
||||||
|
generate_matrix_sanity_test!(test_matrix_u1_u2, U1, U2);
|
||||||
|
generate_matrix_sanity_test!(test_matrix_u2_u2, U2, U2);
|
||||||
|
generate_matrix_sanity_test!(test_matrix_u3_u2, U3, U2);
|
||||||
|
generate_matrix_sanity_test!(test_matrix_u2_u3, U2, U3);
|
||||||
|
generate_matrix_sanity_test!(test_matrix_u3_u3, U3, U3);
|
||||||
|
|
||||||
|
// Similarly test all heap-allocated but fixed dim ranges
|
||||||
|
generate_matrix_sanity_test!(test_matrix_0_0, 0, 0);
|
||||||
|
generate_matrix_sanity_test!(test_matrix_0_1, 0, 1);
|
||||||
|
generate_matrix_sanity_test!(test_matrix_1_0, 1, 0);
|
||||||
|
generate_matrix_sanity_test!(test_matrix_1_1, 1, 1);
|
||||||
|
generate_matrix_sanity_test!(test_matrix_2_1, 2, 1);
|
||||||
|
generate_matrix_sanity_test!(test_matrix_1_2, 1, 2);
|
||||||
|
generate_matrix_sanity_test!(test_matrix_2_2, 2, 2);
|
||||||
|
generate_matrix_sanity_test!(test_matrix_3_2, 3, 2);
|
||||||
|
generate_matrix_sanity_test!(test_matrix_2_3, 2, 3);
|
||||||
|
generate_matrix_sanity_test!(test_matrix_3_3, 3, 3);
|
||||||
|
|
||||||
|
// Test arbitrary inputs
|
||||||
|
generate_matrix_sanity_test!(test_matrix_input_1, U5, 1..=5);
|
||||||
|
generate_matrix_sanity_test!(test_matrix_input_2, 3..=4, 1..=5);
|
||||||
|
generate_matrix_sanity_test!(test_matrix_input_3, 1..=2, U3);
|
||||||
|
generate_matrix_sanity_test!(test_matrix_input_4, 3, U4);
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_matrix_output_types() {
|
||||||
|
// Test that the dimension types are correct for the given inputs
|
||||||
|
let _: MatrixStrategy<_, U3, U4> = matrix(-5..5, U3, U4);
|
||||||
|
let _: MatrixStrategy<_, U3, U3> = matrix(-5..5, U3, U3);
|
||||||
|
let _: MatrixStrategy<_, U3, Dynamic> = matrix(-5..5, U3, 1..=5);
|
||||||
|
let _: MatrixStrategy<_, Dynamic, U3> = matrix(-5..5, 1..=5, U3);
|
||||||
|
let _: MatrixStrategy<_, Dynamic, Dynamic> = matrix(-5..5, 1..=5, 1..=5);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Below we have some tests to ensure that specific instances of MatrixMN are usable
|
||||||
|
// in a typical proptest scenario where we (implicitly) use the `Arbitrary` trait
|
||||||
|
proptest! {
|
||||||
|
#[test]
|
||||||
|
fn ensure_arbitrary_test_compiles_matrix3(_: Matrix3<i32>) {}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn ensure_arbitrary_test_compiles_matrixmn_u3_dynamic(_: MatrixMN<i32, U3, Dynamic>) {}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn ensure_arbitrary_test_compiles_matrixmn_dynamic_u3(_: MatrixMN<i32, Dynamic, U3>) {}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn ensure_arbitrary_test_compiles_dmatrix(_: DMatrix<i32>) {}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn ensure_arbitrary_test_compiles_vector3(_: Vector3<i32>) {}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn ensure_arbitrary_test_compiles_dvector(_: DVector<i32>) {}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn matrix_shrinking_satisfies_constraints() {
|
||||||
|
// We use a deterministic test runner to make the test "stable".
|
||||||
|
let mut runner = TestRunner::deterministic();
|
||||||
|
|
||||||
|
let strategy = matrix(-1..=2, 1..=3, 2..=4);
|
||||||
|
|
||||||
|
let num_matrices = 25;
|
||||||
|
|
||||||
|
macro_rules! maybeprintln {
|
||||||
|
($($arg:tt)*) => {
|
||||||
|
// Uncomment the below line to enable printing of matrix sequences. This is handy
|
||||||
|
// for manually inspecting the sequences of simplified matrices.
|
||||||
|
// println!($($arg)*)
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
maybeprintln!("========================== (begin generation process)");
|
||||||
|
|
||||||
|
for _ in 0..num_matrices {
|
||||||
|
let mut tree = strategy
|
||||||
|
.new_tree(&mut runner)
|
||||||
|
.expect("Tree generation should not fail.");
|
||||||
|
|
||||||
|
let mut current = Some(tree.current());
|
||||||
|
|
||||||
|
maybeprintln!("------------------");
|
||||||
|
|
||||||
|
while let Some(matrix) = current {
|
||||||
|
maybeprintln!("{}", matrix);
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
matrix.iter().all(|&v| v >= -1 && v <= 2),
|
||||||
|
"All matrix elements must satisfy constraints"
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
matrix.nrows() >= 1 && matrix.nrows() <= 3,
|
||||||
|
"Number of rows in matrix must satisfy constraints."
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
matrix.ncols() >= 2 && matrix.ncols() <= 4,
|
||||||
|
"Number of columns in matrix must satisfy constraints."
|
||||||
|
);
|
||||||
|
|
||||||
|
current = if tree.simplify() {
|
||||||
|
Some(tree.current())
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
maybeprintln!("========================== (end of generation process)");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "slow-tests")]
|
||||||
|
mod slow {
|
||||||
|
use super::*;
|
||||||
|
use itertools::Itertools;
|
||||||
|
use std::collections::HashSet;
|
||||||
|
use std::iter::repeat;
|
||||||
|
|
||||||
|
#[cfg(feature = "slow-tests")]
|
||||||
|
#[test]
|
||||||
|
fn matrix_samples_all_possible_outputs() {
|
||||||
|
// Test that the proptest generation covers all possible outputs for a small space of inputs
|
||||||
|
// given enough samples.
|
||||||
|
|
||||||
|
// We use a deterministic test runner to make the test "stable".
|
||||||
|
let mut runner = TestRunner::deterministic();
|
||||||
|
|
||||||
|
// This number needs to be high enough so that we with high probability sample
|
||||||
|
// all possible cases
|
||||||
|
let num_generated_matrices = 200000;
|
||||||
|
|
||||||
|
let values = -1..=1;
|
||||||
|
let rows = 0..=2;
|
||||||
|
let cols = 0..=3;
|
||||||
|
let strategy = matrix(values.clone(), rows.clone(), cols.clone());
|
||||||
|
|
||||||
|
// Enumerate all possible combinations
|
||||||
|
let mut all_combinations = HashSet::new();
|
||||||
|
for nrows in rows {
|
||||||
|
for ncols in cols.clone() {
|
||||||
|
// For the given number of rows and columns
|
||||||
|
let n_values = nrows * ncols;
|
||||||
|
|
||||||
|
if n_values == 0 {
|
||||||
|
// If we have zero rows or columns, the set of matrices with the given
|
||||||
|
// rows and columns is a single element: an empty matrix
|
||||||
|
all_combinations.insert(DMatrix::from_row_slice(nrows, ncols, &[]));
|
||||||
|
} else {
|
||||||
|
// Otherwise, we need to sample all possible matrices.
|
||||||
|
// To do this, we generate the values as the (multi) Cartesian product
|
||||||
|
// of the value sets. For example, for a 2x2 matrices, we consider
|
||||||
|
// all possible 4-element arrays that the matrices can take by
|
||||||
|
// considering all elements in the cartesian product
|
||||||
|
// V x V x V x V
|
||||||
|
// where V is the set of eligible values, e.g. V := -1 ..= 1
|
||||||
|
for matrix_values in repeat(values.clone())
|
||||||
|
.take(n_values)
|
||||||
|
.multi_cartesian_product()
|
||||||
|
{
|
||||||
|
all_combinations.insert(DMatrix::from_row_slice(
|
||||||
|
nrows,
|
||||||
|
ncols,
|
||||||
|
&matrix_values,
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut visited_combinations = HashSet::new();
|
||||||
|
for _ in 0..num_generated_matrices {
|
||||||
|
let tree = strategy
|
||||||
|
.new_tree(&mut runner)
|
||||||
|
.expect("Tree generation should not fail");
|
||||||
|
let matrix = tree.current();
|
||||||
|
visited_combinations.insert(matrix.clone());
|
||||||
|
}
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
visited_combinations, all_combinations,
|
||||||
|
"Did not sample all possible values."
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user