rustfmt
This commit is contained in:
parent
795d818ae5
commit
7473d54d74
|
@ -1,17 +1,17 @@
|
||||||
use crate::coo::CooMatrix;
|
|
||||||
use crate::convert::serial::*;
|
use crate::convert::serial::*;
|
||||||
use nalgebra::{Matrix, Scalar, Dim, ClosedAdd, DMatrix};
|
use crate::coo::CooMatrix;
|
||||||
use nalgebra::storage::{Storage};
|
|
||||||
use num_traits::Zero;
|
|
||||||
use crate::csr::CsrMatrix;
|
|
||||||
use crate::csc::CscMatrix;
|
use crate::csc::CscMatrix;
|
||||||
|
use crate::csr::CsrMatrix;
|
||||||
|
use nalgebra::storage::Storage;
|
||||||
|
use nalgebra::{ClosedAdd, DMatrix, Dim, Matrix, Scalar};
|
||||||
|
use num_traits::Zero;
|
||||||
|
|
||||||
impl<'a, T, R, C, S> From<&'a Matrix<T, R, C, S>> for CooMatrix<T>
|
impl<'a, T, R, C, S> From<&'a Matrix<T, R, C, S>> for CooMatrix<T>
|
||||||
where
|
where
|
||||||
T: Scalar + Zero,
|
T: Scalar + Zero,
|
||||||
R: Dim,
|
R: Dim,
|
||||||
C: Dim,
|
C: Dim,
|
||||||
S: Storage<T, R, C>
|
S: Storage<T, R, C>,
|
||||||
{
|
{
|
||||||
fn from(matrix: &'a Matrix<T, R, C, S>) -> Self {
|
fn from(matrix: &'a Matrix<T, R, C, S>) -> Self {
|
||||||
convert_dense_coo(matrix)
|
convert_dense_coo(matrix)
|
||||||
|
@ -29,7 +29,7 @@ where
|
||||||
|
|
||||||
impl<'a, T> From<&'a CooMatrix<T>> for CsrMatrix<T>
|
impl<'a, T> From<&'a CooMatrix<T>> for CsrMatrix<T>
|
||||||
where
|
where
|
||||||
T: Scalar + Zero + ClosedAdd
|
T: Scalar + Zero + ClosedAdd,
|
||||||
{
|
{
|
||||||
fn from(matrix: &'a CooMatrix<T>) -> Self {
|
fn from(matrix: &'a CooMatrix<T>) -> Self {
|
||||||
convert_coo_csr(matrix)
|
convert_coo_csr(matrix)
|
||||||
|
@ -38,7 +38,7 @@ where
|
||||||
|
|
||||||
impl<'a, T> From<&'a CsrMatrix<T>> for CooMatrix<T>
|
impl<'a, T> From<&'a CsrMatrix<T>> for CooMatrix<T>
|
||||||
where
|
where
|
||||||
T: Scalar + Zero + ClosedAdd
|
T: Scalar + Zero + ClosedAdd,
|
||||||
{
|
{
|
||||||
fn from(matrix: &'a CsrMatrix<T>) -> Self {
|
fn from(matrix: &'a CsrMatrix<T>) -> Self {
|
||||||
convert_csr_coo(matrix)
|
convert_csr_coo(matrix)
|
||||||
|
@ -50,7 +50,7 @@ where
|
||||||
T: Scalar + Zero,
|
T: Scalar + Zero,
|
||||||
R: Dim,
|
R: Dim,
|
||||||
C: Dim,
|
C: Dim,
|
||||||
S: Storage<T, R, C>
|
S: Storage<T, R, C>,
|
||||||
{
|
{
|
||||||
fn from(matrix: &'a Matrix<T, R, C, S>) -> Self {
|
fn from(matrix: &'a Matrix<T, R, C, S>) -> Self {
|
||||||
convert_dense_csr(matrix)
|
convert_dense_csr(matrix)
|
||||||
|
@ -59,7 +59,7 @@ where
|
||||||
|
|
||||||
impl<'a, T> From<&'a CsrMatrix<T>> for DMatrix<T>
|
impl<'a, T> From<&'a CsrMatrix<T>> for DMatrix<T>
|
||||||
where
|
where
|
||||||
T: Scalar + Zero + ClosedAdd
|
T: Scalar + Zero + ClosedAdd,
|
||||||
{
|
{
|
||||||
fn from(matrix: &'a CsrMatrix<T>) -> Self {
|
fn from(matrix: &'a CsrMatrix<T>) -> Self {
|
||||||
convert_csr_dense(matrix)
|
convert_csr_dense(matrix)
|
||||||
|
@ -68,7 +68,7 @@ where
|
||||||
|
|
||||||
impl<'a, T> From<&'a CooMatrix<T>> for CscMatrix<T>
|
impl<'a, T> From<&'a CooMatrix<T>> for CscMatrix<T>
|
||||||
where
|
where
|
||||||
T: Scalar + Zero + ClosedAdd
|
T: Scalar + Zero + ClosedAdd,
|
||||||
{
|
{
|
||||||
fn from(matrix: &'a CooMatrix<T>) -> Self {
|
fn from(matrix: &'a CooMatrix<T>) -> Self {
|
||||||
convert_coo_csc(matrix)
|
convert_coo_csc(matrix)
|
||||||
|
@ -77,7 +77,7 @@ where
|
||||||
|
|
||||||
impl<'a, T> From<&'a CscMatrix<T>> for CooMatrix<T>
|
impl<'a, T> From<&'a CscMatrix<T>> for CooMatrix<T>
|
||||||
where
|
where
|
||||||
T: Scalar + Zero
|
T: Scalar + Zero,
|
||||||
{
|
{
|
||||||
fn from(matrix: &'a CscMatrix<T>) -> Self {
|
fn from(matrix: &'a CscMatrix<T>) -> Self {
|
||||||
convert_csc_coo(matrix)
|
convert_csc_coo(matrix)
|
||||||
|
@ -85,11 +85,11 @@ where
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, T, R, C, S> From<&'a Matrix<T, R, C, S>> for CscMatrix<T>
|
impl<'a, T, R, C, S> From<&'a Matrix<T, R, C, S>> for CscMatrix<T>
|
||||||
where
|
where
|
||||||
T: Scalar + Zero,
|
T: Scalar + Zero,
|
||||||
R: Dim,
|
R: Dim,
|
||||||
C: Dim,
|
C: Dim,
|
||||||
S: Storage<T, R, C>
|
S: Storage<T, R, C>,
|
||||||
{
|
{
|
||||||
fn from(matrix: &'a Matrix<T, R, C, S>) -> Self {
|
fn from(matrix: &'a Matrix<T, R, C, S>) -> Self {
|
||||||
convert_dense_csc(matrix)
|
convert_dense_csc(matrix)
|
||||||
|
@ -97,8 +97,8 @@ impl<'a, T, R, C, S> From<&'a Matrix<T, R, C, S>> for CscMatrix<T>
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, T> From<&'a CscMatrix<T>> for DMatrix<T>
|
impl<'a, T> From<&'a CscMatrix<T>> for DMatrix<T>
|
||||||
where
|
where
|
||||||
T: Scalar + Zero + ClosedAdd
|
T: Scalar + Zero + ClosedAdd,
|
||||||
{
|
{
|
||||||
fn from(matrix: &'a CscMatrix<T>) -> Self {
|
fn from(matrix: &'a CscMatrix<T>) -> Self {
|
||||||
convert_csc_dense(matrix)
|
convert_csc_dense(matrix)
|
||||||
|
@ -106,8 +106,8 @@ impl<'a, T> From<&'a CscMatrix<T>> for DMatrix<T>
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, T> From<&'a CscMatrix<T>> for CsrMatrix<T>
|
impl<'a, T> From<&'a CscMatrix<T>> for CsrMatrix<T>
|
||||||
where
|
where
|
||||||
T: Scalar
|
T: Scalar,
|
||||||
{
|
{
|
||||||
fn from(matrix: &'a CscMatrix<T>) -> Self {
|
fn from(matrix: &'a CscMatrix<T>) -> Self {
|
||||||
convert_csc_csr(matrix)
|
convert_csc_csr(matrix)
|
||||||
|
@ -116,9 +116,9 @@ impl<'a, T> From<&'a CscMatrix<T>> for CsrMatrix<T>
|
||||||
|
|
||||||
impl<'a, T> From<&'a CsrMatrix<T>> for CscMatrix<T>
|
impl<'a, T> From<&'a CsrMatrix<T>> for CscMatrix<T>
|
||||||
where
|
where
|
||||||
T: Scalar
|
T: Scalar,
|
||||||
{
|
{
|
||||||
fn from(matrix: &'a CsrMatrix<T>) -> Self {
|
fn from(matrix: &'a CsrMatrix<T>) -> Self {
|
||||||
convert_csr_csc(matrix)
|
convert_csr_csc(matrix)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,8 +7,8 @@ use std::ops::Add;
|
||||||
|
|
||||||
use num_traits::Zero;
|
use num_traits::Zero;
|
||||||
|
|
||||||
use nalgebra::{ClosedAdd, Dim, DMatrix, Matrix, Scalar};
|
|
||||||
use nalgebra::storage::Storage;
|
use nalgebra::storage::Storage;
|
||||||
|
use nalgebra::{ClosedAdd, DMatrix, Dim, Matrix, Scalar};
|
||||||
|
|
||||||
use crate::coo::CooMatrix;
|
use crate::coo::CooMatrix;
|
||||||
use crate::cs;
|
use crate::cs;
|
||||||
|
@ -21,7 +21,7 @@ where
|
||||||
T: Scalar + Zero,
|
T: Scalar + Zero,
|
||||||
R: Dim,
|
R: Dim,
|
||||||
C: Dim,
|
C: Dim,
|
||||||
S: Storage<T, R, C>
|
S: Storage<T, R, C>,
|
||||||
{
|
{
|
||||||
let mut coo = CooMatrix::new(dense.nrows(), dense.ncols());
|
let mut coo = CooMatrix::new(dense.nrows(), dense.ncols());
|
||||||
|
|
||||||
|
@ -52,12 +52,14 @@ where
|
||||||
/// Converts a [`CooMatrix`] to a [`CsrMatrix`].
|
/// Converts a [`CooMatrix`] to a [`CsrMatrix`].
|
||||||
pub fn convert_coo_csr<T>(coo: &CooMatrix<T>) -> CsrMatrix<T>
|
pub fn convert_coo_csr<T>(coo: &CooMatrix<T>) -> CsrMatrix<T>
|
||||||
where
|
where
|
||||||
T: Scalar + Zero
|
T: Scalar + Zero,
|
||||||
{
|
{
|
||||||
let (offsets, indices, values) = convert_coo_cs(coo.nrows(),
|
let (offsets, indices, values) = convert_coo_cs(
|
||||||
coo.row_indices(),
|
coo.nrows(),
|
||||||
coo.col_indices(),
|
coo.row_indices(),
|
||||||
coo.values());
|
coo.col_indices(),
|
||||||
|
coo.values(),
|
||||||
|
);
|
||||||
|
|
||||||
// TODO: Avoid "try_from" since it validates the data? (requires unsafe, should benchmark
|
// TODO: Avoid "try_from" since it validates the data? (requires unsafe, should benchmark
|
||||||
// to see if it can be justified for performance reasons)
|
// to see if it can be justified for performance reasons)
|
||||||
|
@ -66,8 +68,7 @@ where
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Converts a [`CsrMatrix`] to a [`CooMatrix`].
|
/// Converts a [`CsrMatrix`] to a [`CooMatrix`].
|
||||||
pub fn convert_csr_coo<T: Scalar>(csr: &CsrMatrix<T>) -> CooMatrix<T>
|
pub fn convert_csr_coo<T: Scalar>(csr: &CsrMatrix<T>) -> CooMatrix<T> {
|
||||||
{
|
|
||||||
let mut result = CooMatrix::new(csr.nrows(), csr.ncols());
|
let mut result = CooMatrix::new(csr.nrows(), csr.ncols());
|
||||||
for (i, j, v) in csr.triplet_iter() {
|
for (i, j, v) in csr.triplet_iter() {
|
||||||
result.push(i, j, v.inlined_clone());
|
result.push(i, j, v.inlined_clone());
|
||||||
|
@ -76,9 +77,9 @@ pub fn convert_csr_coo<T: Scalar>(csr: &CsrMatrix<T>) -> CooMatrix<T>
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Converts a [`CsrMatrix`] to a dense matrix.
|
/// Converts a [`CsrMatrix`] to a dense matrix.
|
||||||
pub fn convert_csr_dense<T>(csr:& CsrMatrix<T>) -> DMatrix<T>
|
pub fn convert_csr_dense<T>(csr: &CsrMatrix<T>) -> DMatrix<T>
|
||||||
where
|
where
|
||||||
T: Scalar + ClosedAdd + Zero
|
T: Scalar + ClosedAdd + Zero,
|
||||||
{
|
{
|
||||||
let mut output = DMatrix::zeros(csr.nrows(), csr.ncols());
|
let mut output = DMatrix::zeros(csr.nrows(), csr.ncols());
|
||||||
|
|
||||||
|
@ -95,7 +96,7 @@ where
|
||||||
T: Scalar + Zero,
|
T: Scalar + Zero,
|
||||||
R: Dim,
|
R: Dim,
|
||||||
C: Dim,
|
C: Dim,
|
||||||
S: Storage<T, R, C>
|
S: Storage<T, R, C>,
|
||||||
{
|
{
|
||||||
let mut row_offsets = Vec::with_capacity(dense.nrows() + 1);
|
let mut row_offsets = Vec::with_capacity(dense.nrows() + 1);
|
||||||
let mut col_idx = Vec::new();
|
let mut col_idx = Vec::new();
|
||||||
|
@ -105,8 +106,8 @@ where
|
||||||
// nalgebra's column-major storage. The alternative would be to perform an initial sweep
|
// nalgebra's column-major storage. The alternative would be to perform an initial sweep
|
||||||
// to count number of non-zeros per row.
|
// to count number of non-zeros per row.
|
||||||
row_offsets.push(0);
|
row_offsets.push(0);
|
||||||
for i in 0 .. dense.nrows() {
|
for i in 0..dense.nrows() {
|
||||||
for j in 0 .. dense.ncols() {
|
for j in 0..dense.ncols() {
|
||||||
let v = dense.index((i, j));
|
let v = dense.index((i, j));
|
||||||
if v != &T::zero() {
|
if v != &T::zero() {
|
||||||
col_idx.push(j);
|
col_idx.push(j);
|
||||||
|
@ -125,12 +126,14 @@ where
|
||||||
/// Converts a [`CooMatrix`] to a [`CscMatrix`].
|
/// Converts a [`CooMatrix`] to a [`CscMatrix`].
|
||||||
pub fn convert_coo_csc<T>(coo: &CooMatrix<T>) -> CscMatrix<T>
|
pub fn convert_coo_csc<T>(coo: &CooMatrix<T>) -> CscMatrix<T>
|
||||||
where
|
where
|
||||||
T: Scalar + Zero
|
T: Scalar + Zero,
|
||||||
{
|
{
|
||||||
let (offsets, indices, values) = convert_coo_cs(coo.ncols(),
|
let (offsets, indices, values) = convert_coo_cs(
|
||||||
coo.col_indices(),
|
coo.ncols(),
|
||||||
coo.row_indices(),
|
coo.col_indices(),
|
||||||
coo.values());
|
coo.row_indices(),
|
||||||
|
coo.values(),
|
||||||
|
);
|
||||||
|
|
||||||
// TODO: Avoid "try_from" since it validates the data? (requires unsafe, should benchmark
|
// TODO: Avoid "try_from" since it validates the data? (requires unsafe, should benchmark
|
||||||
// to see if it can be justified for performance reasons)
|
// to see if it can be justified for performance reasons)
|
||||||
|
@ -141,7 +144,7 @@ where
|
||||||
/// Converts a [`CscMatrix`] to a [`CooMatrix`].
|
/// Converts a [`CscMatrix`] to a [`CooMatrix`].
|
||||||
pub fn convert_csc_coo<T>(csc: &CscMatrix<T>) -> CooMatrix<T>
|
pub fn convert_csc_coo<T>(csc: &CscMatrix<T>) -> CooMatrix<T>
|
||||||
where
|
where
|
||||||
T: Scalar
|
T: Scalar,
|
||||||
{
|
{
|
||||||
let mut coo = CooMatrix::new(csc.nrows(), csc.ncols());
|
let mut coo = CooMatrix::new(csc.nrows(), csc.ncols());
|
||||||
for (i, j, v) in csc.triplet_iter() {
|
for (i, j, v) in csc.triplet_iter() {
|
||||||
|
@ -153,7 +156,7 @@ where
|
||||||
/// Converts a [`CscMatrix`] to a dense matrix.
|
/// Converts a [`CscMatrix`] to a dense matrix.
|
||||||
pub fn convert_csc_dense<T>(csc: &CscMatrix<T>) -> DMatrix<T>
|
pub fn convert_csc_dense<T>(csc: &CscMatrix<T>) -> DMatrix<T>
|
||||||
where
|
where
|
||||||
T: Scalar + ClosedAdd + Zero
|
T: Scalar + ClosedAdd + Zero,
|
||||||
{
|
{
|
||||||
let mut output = DMatrix::zeros(csc.nrows(), csc.ncols());
|
let mut output = DMatrix::zeros(csc.nrows(), csc.ncols());
|
||||||
|
|
||||||
|
@ -166,19 +169,19 @@ where
|
||||||
|
|
||||||
/// Converts a dense matrix to a [`CscMatrix`].
|
/// Converts a dense matrix to a [`CscMatrix`].
|
||||||
pub fn convert_dense_csc<T, R, C, S>(dense: &Matrix<T, R, C, S>) -> CscMatrix<T>
|
pub fn convert_dense_csc<T, R, C, S>(dense: &Matrix<T, R, C, S>) -> CscMatrix<T>
|
||||||
where
|
where
|
||||||
T: Scalar + Zero,
|
T: Scalar + Zero,
|
||||||
R: Dim,
|
R: Dim,
|
||||||
C: Dim,
|
C: Dim,
|
||||||
S: Storage<T, R, C>
|
S: Storage<T, R, C>,
|
||||||
{
|
{
|
||||||
let mut col_offsets = Vec::with_capacity(dense.ncols() + 1);
|
let mut col_offsets = Vec::with_capacity(dense.ncols() + 1);
|
||||||
let mut row_idx = Vec::new();
|
let mut row_idx = Vec::new();
|
||||||
let mut values = Vec::new();
|
let mut values = Vec::new();
|
||||||
|
|
||||||
col_offsets.push(0);
|
col_offsets.push(0);
|
||||||
for j in 0 .. dense.ncols() {
|
for j in 0..dense.ncols() {
|
||||||
for i in 0 .. dense.nrows() {
|
for i in 0..dense.nrows() {
|
||||||
let v = dense.index((i, j));
|
let v = dense.index((i, j));
|
||||||
if v != &T::zero() {
|
if v != &T::zero() {
|
||||||
row_idx.push(i);
|
row_idx.push(i);
|
||||||
|
@ -197,13 +200,15 @@ pub fn convert_dense_csc<T, R, C, S>(dense: &Matrix<T, R, C, S>) -> CscMatrix<T>
|
||||||
/// Converts a [`CsrMatrix`] to a [`CscMatrix`].
|
/// Converts a [`CsrMatrix`] to a [`CscMatrix`].
|
||||||
pub fn convert_csr_csc<T>(csr: &CsrMatrix<T>) -> CscMatrix<T>
|
pub fn convert_csr_csc<T>(csr: &CsrMatrix<T>) -> CscMatrix<T>
|
||||||
where
|
where
|
||||||
T: Scalar
|
T: Scalar,
|
||||||
{
|
{
|
||||||
let (offsets, indices, values) = cs::transpose_cs(csr.nrows(),
|
let (offsets, indices, values) = cs::transpose_cs(
|
||||||
csr.ncols(),
|
csr.nrows(),
|
||||||
csr.row_offsets(),
|
csr.ncols(),
|
||||||
csr.col_indices(),
|
csr.row_offsets(),
|
||||||
csr.values());
|
csr.col_indices(),
|
||||||
|
csr.values(),
|
||||||
|
);
|
||||||
|
|
||||||
// TODO: Avoid data validity check?
|
// TODO: Avoid data validity check?
|
||||||
CscMatrix::try_from_csc_data(csr.nrows(), csr.ncols(), offsets, indices, values)
|
CscMatrix::try_from_csc_data(csr.nrows(), csr.ncols(), offsets, indices, values)
|
||||||
|
@ -212,27 +217,30 @@ where
|
||||||
|
|
||||||
/// Converts a [`CscMatrix`] to a [`CsrMatrix`].
|
/// Converts a [`CscMatrix`] to a [`CsrMatrix`].
|
||||||
pub fn convert_csc_csr<T>(csc: &CscMatrix<T>) -> CsrMatrix<T>
|
pub fn convert_csc_csr<T>(csc: &CscMatrix<T>) -> CsrMatrix<T>
|
||||||
where
|
where
|
||||||
T: Scalar
|
T: Scalar,
|
||||||
{
|
{
|
||||||
let (offsets, indices, values) = cs::transpose_cs(csc.ncols(),
|
let (offsets, indices, values) = cs::transpose_cs(
|
||||||
csc.nrows(),
|
csc.ncols(),
|
||||||
csc.col_offsets(),
|
csc.nrows(),
|
||||||
csc.row_indices(),
|
csc.col_offsets(),
|
||||||
csc.values());
|
csc.row_indices(),
|
||||||
|
csc.values(),
|
||||||
|
);
|
||||||
|
|
||||||
// TODO: Avoid data validity check?
|
// TODO: Avoid data validity check?
|
||||||
CsrMatrix::try_from_csr_data(csc.nrows(), csc.ncols(), offsets, indices, values)
|
CsrMatrix::try_from_csr_data(csc.nrows(), csc.ncols(), offsets, indices, values)
|
||||||
.expect("Internal error: Invalid CSR data during CSC->CSR conversion")
|
.expect("Internal error: Invalid CSR data during CSC->CSR conversion")
|
||||||
}
|
}
|
||||||
|
|
||||||
fn convert_coo_cs<T>(major_dim: usize,
|
fn convert_coo_cs<T>(
|
||||||
major_indices: &[usize],
|
major_dim: usize,
|
||||||
minor_indices: &[usize],
|
major_indices: &[usize],
|
||||||
values: &[T])
|
minor_indices: &[usize],
|
||||||
-> (Vec<usize>, Vec<usize>, Vec<T>)
|
values: &[T],
|
||||||
|
) -> (Vec<usize>, Vec<usize>, Vec<T>)
|
||||||
where
|
where
|
||||||
T: Scalar + Zero
|
T: Scalar + Zero,
|
||||||
{
|
{
|
||||||
assert_eq!(major_indices.len(), minor_indices.len());
|
assert_eq!(major_indices.len(), minor_indices.len());
|
||||||
assert_eq!(minor_indices.len(), values.len());
|
assert_eq!(minor_indices.len(), values.len());
|
||||||
|
@ -416,4 +424,4 @@ fn combine_duplicates<T: Clone>(
|
||||||
produce_value(combined_value);
|
produce_value(combined_value);
|
||||||
i = j;
|
i = j;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -45,8 +45,7 @@ pub struct CooMatrix<T> {
|
||||||
values: Vec<T>,
|
values: Vec<T>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T> CooMatrix<T>
|
impl<T> CooMatrix<T> {
|
||||||
{
|
|
||||||
/// Construct a zero COO matrix of the given dimensions.
|
/// Construct a zero COO matrix of the given dimensions.
|
||||||
///
|
///
|
||||||
/// Specifically, the collection of triplets - corresponding to explicitly stored entries -
|
/// Specifically, the collection of triplets - corresponding to explicitly stored entries -
|
||||||
|
@ -78,11 +77,13 @@ impl<T> CooMatrix<T>
|
||||||
use crate::SparseFormatErrorKind::*;
|
use crate::SparseFormatErrorKind::*;
|
||||||
if row_indices.len() != col_indices.len() {
|
if row_indices.len() != col_indices.len() {
|
||||||
return Err(SparseFormatError::from_kind_and_msg(
|
return Err(SparseFormatError::from_kind_and_msg(
|
||||||
InvalidStructure, "Number of row and col indices must be the same."
|
InvalidStructure,
|
||||||
|
"Number of row and col indices must be the same.",
|
||||||
));
|
));
|
||||||
} else if col_indices.len() != values.len() {
|
} else if col_indices.len() != values.len() {
|
||||||
return Err(SparseFormatError::from_kind_and_msg(
|
return Err(SparseFormatError::from_kind_and_msg(
|
||||||
InvalidStructure, "Number of col indices and values must be the same."
|
InvalidStructure,
|
||||||
|
"Number of col indices and values must be the same.",
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -90,9 +91,15 @@ impl<T> CooMatrix<T>
|
||||||
let col_indices_in_bounds = col_indices.iter().all(|j| *j < ncols);
|
let col_indices_in_bounds = col_indices.iter().all(|j| *j < ncols);
|
||||||
|
|
||||||
if !row_indices_in_bounds {
|
if !row_indices_in_bounds {
|
||||||
Err(SparseFormatError::from_kind_and_msg(IndexOutOfBounds, "Row index out of bounds."))
|
Err(SparseFormatError::from_kind_and_msg(
|
||||||
|
IndexOutOfBounds,
|
||||||
|
"Row index out of bounds.",
|
||||||
|
))
|
||||||
} else if !col_indices_in_bounds {
|
} else if !col_indices_in_bounds {
|
||||||
Err(SparseFormatError::from_kind_and_msg(IndexOutOfBounds, "Col index out of bounds."))
|
Err(SparseFormatError::from_kind_and_msg(
|
||||||
|
IndexOutOfBounds,
|
||||||
|
"Col index out of bounds.",
|
||||||
|
))
|
||||||
} else {
|
} else {
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
nrows,
|
nrows,
|
||||||
|
|
|
@ -5,8 +5,8 @@ use num_traits::One;
|
||||||
|
|
||||||
use nalgebra::Scalar;
|
use nalgebra::Scalar;
|
||||||
|
|
||||||
use crate::{SparseEntry, SparseEntryMut};
|
|
||||||
use crate::pattern::SparsityPattern;
|
use crate::pattern::SparsityPattern;
|
||||||
|
use crate::{SparseEntry, SparseEntryMut};
|
||||||
|
|
||||||
/// An abstract compressed matrix.
|
/// An abstract compressed matrix.
|
||||||
///
|
///
|
||||||
|
@ -18,7 +18,7 @@ use crate::pattern::SparsityPattern;
|
||||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
pub struct CsMatrix<T> {
|
pub struct CsMatrix<T> {
|
||||||
sparsity_pattern: SparsityPattern,
|
sparsity_pattern: SparsityPattern,
|
||||||
values: Vec<T>
|
values: Vec<T>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T> CsMatrix<T> {
|
impl<T> CsMatrix<T> {
|
||||||
|
@ -50,14 +50,22 @@ impl<T> CsMatrix<T> {
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn cs_data(&self) -> (&[usize], &[usize], &[T]) {
|
pub fn cs_data(&self) -> (&[usize], &[usize], &[T]) {
|
||||||
let pattern = self.pattern();
|
let pattern = self.pattern();
|
||||||
(pattern.major_offsets(), pattern.minor_indices(), &self.values)
|
(
|
||||||
|
pattern.major_offsets(),
|
||||||
|
pattern.minor_indices(),
|
||||||
|
&self.values,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns the raw data represented as a tuple `(major_offsets, minor_indices, values)`.
|
/// Returns the raw data represented as a tuple `(major_offsets, minor_indices, values)`.
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn cs_data_mut(&mut self) -> (&[usize], &[usize], &mut [T]) {
|
pub fn cs_data_mut(&mut self) -> (&[usize], &[usize], &mut [T]) {
|
||||||
let pattern = &mut self.sparsity_pattern;
|
let pattern = &mut self.sparsity_pattern;
|
||||||
(pattern.major_offsets(), pattern.minor_indices(), &mut self.values)
|
(
|
||||||
|
pattern.major_offsets(),
|
||||||
|
pattern.minor_indices(),
|
||||||
|
&mut self.values,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
|
@ -66,9 +74,12 @@ impl<T> CsMatrix<T> {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn from_pattern_and_values(pattern: SparsityPattern, values: Vec<T>)
|
pub fn from_pattern_and_values(pattern: SparsityPattern, values: Vec<T>) -> Self {
|
||||||
-> Self {
|
assert_eq!(
|
||||||
assert_eq!(pattern.nnz(), values.len(), "Internal error: consumers should verify shape compatibility.");
|
pattern.nnz(),
|
||||||
|
values.len(),
|
||||||
|
"Internal error: consumers should verify shape compatibility."
|
||||||
|
);
|
||||||
Self {
|
Self {
|
||||||
sparsity_pattern: pattern,
|
sparsity_pattern: pattern,
|
||||||
values,
|
values,
|
||||||
|
@ -80,7 +91,7 @@ impl<T> CsMatrix<T> {
|
||||||
pub fn get_index_range(&self, row_index: usize) -> Option<Range<usize>> {
|
pub fn get_index_range(&self, row_index: usize) -> Option<Range<usize>> {
|
||||||
let row_begin = *self.sparsity_pattern.major_offsets().get(row_index)?;
|
let row_begin = *self.sparsity_pattern.major_offsets().get(row_index)?;
|
||||||
let row_end = *self.sparsity_pattern.major_offsets().get(row_index + 1)?;
|
let row_end = *self.sparsity_pattern.major_offsets().get(row_index + 1)?;
|
||||||
Some(row_begin .. row_end)
|
Some(row_begin..row_end)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn take_pattern_and_values(self) -> (SparsityPattern, Vec<T>) {
|
pub fn take_pattern_and_values(self) -> (SparsityPattern, Vec<T>) {
|
||||||
|
@ -105,13 +116,21 @@ impl<T> CsMatrix<T> {
|
||||||
let (_, minor_indices, values) = self.cs_data();
|
let (_, minor_indices, values) = self.cs_data();
|
||||||
let minor_indices = &minor_indices[row_range.clone()];
|
let minor_indices = &minor_indices[row_range.clone()];
|
||||||
let values = &values[row_range];
|
let values = &values[row_range];
|
||||||
get_entry_from_slices(self.pattern().minor_dim(), minor_indices, values, minor_index)
|
get_entry_from_slices(
|
||||||
|
self.pattern().minor_dim(),
|
||||||
|
minor_indices,
|
||||||
|
values,
|
||||||
|
minor_index,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns a mutable entry for the given major/minor indices, or `None` if the indices are out
|
/// Returns a mutable entry for the given major/minor indices, or `None` if the indices are out
|
||||||
/// of bounds.
|
/// of bounds.
|
||||||
pub fn get_entry_mut(&mut self, major_index: usize, minor_index: usize)
|
pub fn get_entry_mut(
|
||||||
-> Option<SparseEntryMut<T>> {
|
&mut self,
|
||||||
|
major_index: usize,
|
||||||
|
minor_index: usize,
|
||||||
|
) -> Option<SparseEntryMut<T>> {
|
||||||
let row_range = self.get_index_range(major_index)?;
|
let row_range = self.get_index_range(major_index)?;
|
||||||
let minor_dim = self.pattern().minor_dim();
|
let minor_dim = self.pattern().minor_dim();
|
||||||
let (_, minor_indices, values) = self.cs_data_mut();
|
let (_, minor_indices, values) = self.cs_data_mut();
|
||||||
|
@ -126,7 +145,7 @@ impl<T> CsMatrix<T> {
|
||||||
Some(CsLane {
|
Some(CsLane {
|
||||||
minor_indices: &minor_indices[range.clone()],
|
minor_indices: &minor_indices[range.clone()],
|
||||||
values: &values[range],
|
values: &values[range],
|
||||||
minor_dim: self.pattern().minor_dim()
|
minor_dim: self.pattern().minor_dim(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -138,7 +157,7 @@ impl<T> CsMatrix<T> {
|
||||||
Some(CsLaneMut {
|
Some(CsLaneMut {
|
||||||
minor_dim,
|
minor_dim,
|
||||||
minor_indices: &minor_indices[range.clone()],
|
minor_indices: &minor_indices[range.clone()],
|
||||||
values: &mut values[range]
|
values: &mut values[range],
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -156,7 +175,7 @@ impl<T> CsMatrix<T> {
|
||||||
pub fn filter<P>(&self, predicate: P) -> Self
|
pub fn filter<P>(&self, predicate: P) -> Self
|
||||||
where
|
where
|
||||||
T: Clone,
|
T: Clone,
|
||||||
P: Fn(usize, usize, &T) -> bool
|
P: Fn(usize, usize, &T) -> bool,
|
||||||
{
|
{
|
||||||
let (major_dim, minor_dim) = (self.pattern().major_dim(), self.pattern().minor_dim());
|
let (major_dim, minor_dim) = (self.pattern().major_dim(), self.pattern().minor_dim());
|
||||||
let mut new_offsets = Vec::with_capacity(self.pattern().major_dim() + 1);
|
let mut new_offsets = Vec::with_capacity(self.pattern().major_dim() + 1);
|
||||||
|
@ -180,16 +199,17 @@ impl<T> CsMatrix<T> {
|
||||||
major_dim,
|
major_dim,
|
||||||
minor_dim,
|
minor_dim,
|
||||||
new_offsets,
|
new_offsets,
|
||||||
new_indices)
|
new_indices,
|
||||||
.expect("Internal error: Sparsity pattern must always be valid.");
|
)
|
||||||
|
.expect("Internal error: Sparsity pattern must always be valid.");
|
||||||
|
|
||||||
Self::from_pattern_and_values(new_pattern, new_values)
|
Self::from_pattern_and_values(new_pattern, new_values)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns the diagonal of the matrix as a sparse matrix.
|
/// Returns the diagonal of the matrix as a sparse matrix.
|
||||||
pub fn diagonal_as_matrix(&self) -> Self
|
pub fn diagonal_as_matrix(&self) -> Self
|
||||||
where
|
where
|
||||||
T: Clone
|
T: Clone,
|
||||||
{
|
{
|
||||||
// TODO: This might be faster with a binary search for each diagonal entry
|
// TODO: This might be faster with a binary search for each diagonal entry
|
||||||
self.filter(|i, j, _| i == j)
|
self.filter(|i, j, _| i == j)
|
||||||
|
@ -199,13 +219,13 @@ impl<T> CsMatrix<T> {
|
||||||
impl<T: Scalar + One> CsMatrix<T> {
|
impl<T: Scalar + One> CsMatrix<T> {
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn identity(n: usize) -> Self {
|
pub fn identity(n: usize) -> Self {
|
||||||
let offsets: Vec<_> = (0 ..= n).collect();
|
let offsets: Vec<_> = (0..=n).collect();
|
||||||
let indices: Vec<_> = (0 .. n).collect();
|
let indices: Vec<_> = (0..n).collect();
|
||||||
let values = vec![T::one(); n];
|
let values = vec![T::one(); n];
|
||||||
|
|
||||||
// TODO: We should skip checks here
|
// TODO: We should skip checks here
|
||||||
let pattern = SparsityPattern::try_from_offsets_and_indices(n, n, offsets, indices)
|
let pattern =
|
||||||
.unwrap();
|
SparsityPattern::try_from_offsets_and_indices(n, n, offsets, indices).unwrap();
|
||||||
Self::from_pattern_and_values(pattern, values)
|
Self::from_pattern_and_values(pattern, values)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -214,7 +234,8 @@ fn get_entry_from_slices<'a, T>(
|
||||||
minor_dim: usize,
|
minor_dim: usize,
|
||||||
minor_indices: &'a [usize],
|
minor_indices: &'a [usize],
|
||||||
values: &'a [T],
|
values: &'a [T],
|
||||||
global_minor_index: usize) -> Option<SparseEntry<'a, T>> {
|
global_minor_index: usize,
|
||||||
|
) -> Option<SparseEntry<'a, T>> {
|
||||||
let local_index = minor_indices.binary_search(&global_minor_index);
|
let local_index = minor_indices.binary_search(&global_minor_index);
|
||||||
if let Ok(local_index) = local_index {
|
if let Ok(local_index) = local_index {
|
||||||
Some(SparseEntry::NonZero(&values[local_index]))
|
Some(SparseEntry::NonZero(&values[local_index]))
|
||||||
|
@ -229,7 +250,8 @@ fn get_mut_entry_from_slices<'a, T>(
|
||||||
minor_dim: usize,
|
minor_dim: usize,
|
||||||
minor_indices: &'a [usize],
|
minor_indices: &'a [usize],
|
||||||
values: &'a mut [T],
|
values: &'a mut [T],
|
||||||
global_minor_indices: usize) -> Option<SparseEntryMut<'a, T>> {
|
global_minor_indices: usize,
|
||||||
|
) -> Option<SparseEntryMut<'a, T>> {
|
||||||
let local_index = minor_indices.binary_search(&global_minor_indices);
|
let local_index = minor_indices.binary_search(&global_minor_indices);
|
||||||
if let Ok(local_index) = local_index {
|
if let Ok(local_index) = local_index {
|
||||||
Some(SparseEntryMut::NonZero(&mut values[local_index]))
|
Some(SparseEntryMut::NonZero(&mut values[local_index]))
|
||||||
|
@ -244,14 +266,14 @@ fn get_mut_entry_from_slices<'a, T>(
|
||||||
pub struct CsLane<'a, T> {
|
pub struct CsLane<'a, T> {
|
||||||
minor_dim: usize,
|
minor_dim: usize,
|
||||||
minor_indices: &'a [usize],
|
minor_indices: &'a [usize],
|
||||||
values: &'a [T]
|
values: &'a [T],
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, PartialEq, Eq)]
|
#[derive(Debug, PartialEq, Eq)]
|
||||||
pub struct CsLaneMut<'a, T> {
|
pub struct CsLaneMut<'a, T> {
|
||||||
minor_dim: usize,
|
minor_dim: usize,
|
||||||
minor_indices: &'a [usize],
|
minor_indices: &'a [usize],
|
||||||
values: &'a mut [T]
|
values: &'a mut [T],
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct CsLaneIter<'a, T> {
|
pub struct CsLaneIter<'a, T> {
|
||||||
|
@ -266,14 +288,14 @@ impl<'a, T> CsLaneIter<'a, T> {
|
||||||
Self {
|
Self {
|
||||||
current_lane_idx: 0,
|
current_lane_idx: 0,
|
||||||
pattern,
|
pattern,
|
||||||
remaining_values: values
|
remaining_values: values,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, T> Iterator for CsLaneIter<'a, T>
|
impl<'a, T> Iterator for CsLaneIter<'a, T>
|
||||||
where
|
where
|
||||||
T: 'a
|
T: 'a,
|
||||||
{
|
{
|
||||||
type Item = CsLane<'a, T>;
|
type Item = CsLane<'a, T>;
|
||||||
|
|
||||||
|
@ -284,13 +306,13 @@ impl<'a, T> Iterator for CsLaneIter<'a, T>
|
||||||
if let Some(minor_indices) = lane {
|
if let Some(minor_indices) = lane {
|
||||||
let count = minor_indices.len();
|
let count = minor_indices.len();
|
||||||
let values_in_lane = &self.remaining_values[..count];
|
let values_in_lane = &self.remaining_values[..count];
|
||||||
self.remaining_values = &self.remaining_values[count ..];
|
self.remaining_values = &self.remaining_values[count..];
|
||||||
self.current_lane_idx += 1;
|
self.current_lane_idx += 1;
|
||||||
|
|
||||||
Some(CsLane {
|
Some(CsLane {
|
||||||
minor_dim,
|
minor_dim,
|
||||||
minor_indices,
|
minor_indices,
|
||||||
values: values_in_lane
|
values: values_in_lane,
|
||||||
})
|
})
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
|
@ -310,14 +332,14 @@ impl<'a, T> CsLaneIterMut<'a, T> {
|
||||||
Self {
|
Self {
|
||||||
current_lane_idx: 0,
|
current_lane_idx: 0,
|
||||||
pattern,
|
pattern,
|
||||||
remaining_values: values
|
remaining_values: values,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, T> Iterator for CsLaneIterMut<'a, T>
|
impl<'a, T> Iterator for CsLaneIterMut<'a, T>
|
||||||
where
|
where
|
||||||
T: 'a
|
T: 'a,
|
||||||
{
|
{
|
||||||
type Item = CsLaneMut<'a, T>;
|
type Item = CsLaneMut<'a, T>;
|
||||||
|
|
||||||
|
@ -336,7 +358,7 @@ impl<'a, T> Iterator for CsLaneIterMut<'a, T>
|
||||||
Some(CsLaneMut {
|
Some(CsLaneMut {
|
||||||
minor_dim,
|
minor_dim,
|
||||||
minor_indices,
|
minor_indices,
|
||||||
values: values_in_lane
|
values: values_in_lane,
|
||||||
})
|
})
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
|
@ -375,10 +397,11 @@ macro_rules! impl_cs_lane_common_methods {
|
||||||
self.minor_dim,
|
self.minor_dim,
|
||||||
self.minor_indices,
|
self.minor_indices,
|
||||||
self.values,
|
self.values,
|
||||||
global_col_index)
|
global_col_index,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
impl_cs_lane_common_methods!(CsLane<'a, T>);
|
impl_cs_lane_common_methods!(CsLane<'a, T>);
|
||||||
|
@ -394,10 +417,12 @@ impl<'a, T> CsLaneMut<'a, T> {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_entry_mut(&mut self, global_minor_index: usize) -> Option<SparseEntryMut<T>> {
|
pub fn get_entry_mut(&mut self, global_minor_index: usize) -> Option<SparseEntryMut<T>> {
|
||||||
get_mut_entry_from_slices(self.minor_dim,
|
get_mut_entry_from_slices(
|
||||||
self.minor_indices,
|
self.minor_dim,
|
||||||
self.values,
|
self.minor_indices,
|
||||||
global_minor_index)
|
self.values,
|
||||||
|
global_minor_index,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -405,7 +430,7 @@ impl<'a, T> CsLaneMut<'a, T> {
|
||||||
/// TODO: This doesn't belong here.
|
/// TODO: This doesn't belong here.
|
||||||
struct UninitVec<T> {
|
struct UninitVec<T> {
|
||||||
vec: Vec<T>,
|
vec: Vec<T>,
|
||||||
len: usize
|
len: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T> UninitVec<T> {
|
impl<T> UninitVec<T> {
|
||||||
|
@ -414,7 +439,7 @@ impl<T> UninitVec<T> {
|
||||||
vec: Vec::with_capacity(len),
|
vec: Vec::with_capacity(len),
|
||||||
// We need to store len separately, because for zero-sized types,
|
// We need to store len separately, because for zero-sized types,
|
||||||
// Vec::with_capacity(len) does not give vec.capacity() == len
|
// Vec::with_capacity(len) does not give vec.capacity() == len
|
||||||
len
|
len,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -440,14 +465,14 @@ impl<T> UninitVec<T> {
|
||||||
/// This means that major and minor roles are switched. This is used for converting between CSR
|
/// This means that major and minor roles are switched. This is used for converting between CSR
|
||||||
/// and CSC formats.
|
/// and CSC formats.
|
||||||
pub fn transpose_cs<T>(
|
pub fn transpose_cs<T>(
|
||||||
major_dim: usize,
|
major_dim: usize,
|
||||||
minor_dim: usize,
|
minor_dim: usize,
|
||||||
source_major_offsets: &[usize],
|
source_major_offsets: &[usize],
|
||||||
source_minor_indices: &[usize],
|
source_minor_indices: &[usize],
|
||||||
values: &[T])
|
values: &[T],
|
||||||
-> (Vec<usize>, Vec<usize>, Vec<T>)
|
) -> (Vec<usize>, Vec<usize>, Vec<T>)
|
||||||
where
|
where
|
||||||
T: Scalar
|
T: Scalar,
|
||||||
{
|
{
|
||||||
assert_eq!(source_major_offsets.len(), major_dim + 1);
|
assert_eq!(source_major_offsets.len(), major_dim + 1);
|
||||||
assert_eq!(source_minor_indices.len(), values.len());
|
assert_eq!(source_minor_indices.len(), values.len());
|
||||||
|
@ -470,18 +495,20 @@ where
|
||||||
// Keep track of how many entries we have placed in each target major lane
|
// Keep track of how many entries we have placed in each target major lane
|
||||||
let mut current_target_major_counts = vec![0; minor_dim];
|
let mut current_target_major_counts = vec![0; minor_dim];
|
||||||
|
|
||||||
for source_major_idx in 0 .. major_dim {
|
for source_major_idx in 0..major_dim {
|
||||||
let source_lane_begin = source_major_offsets[source_major_idx];
|
let source_lane_begin = source_major_offsets[source_major_idx];
|
||||||
let source_lane_end = source_major_offsets[source_major_idx + 1];
|
let source_lane_end = source_major_offsets[source_major_idx + 1];
|
||||||
let source_lane_indices = &source_minor_indices[source_lane_begin .. source_lane_end];
|
let source_lane_indices = &source_minor_indices[source_lane_begin..source_lane_end];
|
||||||
let source_lane_values = &values[source_lane_begin .. source_lane_end];
|
let source_lane_values = &values[source_lane_begin..source_lane_end];
|
||||||
|
|
||||||
for (&source_minor_idx, val) in source_lane_indices.iter().zip(source_lane_values) {
|
for (&source_minor_idx, val) in source_lane_indices.iter().zip(source_lane_values) {
|
||||||
// Compute the offset in the target data for this particular source entry
|
// Compute the offset in the target data for this particular source entry
|
||||||
let target_lane_count = &mut current_target_major_counts[source_minor_idx];
|
let target_lane_count = &mut current_target_major_counts[source_minor_idx];
|
||||||
let entry_offset = target_offsets[source_minor_idx] + *target_lane_count;
|
let entry_offset = target_offsets[source_minor_idx] + *target_lane_count;
|
||||||
target_indices[entry_offset] = source_major_idx;
|
target_indices[entry_offset] = source_major_idx;
|
||||||
unsafe { target_values.set(entry_offset, val.inlined_clone()); }
|
unsafe {
|
||||||
|
target_values.set(entry_offset, val.inlined_clone());
|
||||||
|
}
|
||||||
*target_lane_count += 1;
|
*target_lane_count += 1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,14 +3,14 @@
|
||||||
//! This is the module-level documentation. See [`CscMatrix`] for the main documentation of the
|
//! This is the module-level documentation. See [`CscMatrix`] for the main documentation of the
|
||||||
//! CSC implementation.
|
//! CSC implementation.
|
||||||
|
|
||||||
use crate::{SparseFormatError, SparseFormatErrorKind, SparseEntry, SparseEntryMut};
|
use crate::cs::{CsLane, CsLaneIter, CsLaneIterMut, CsLaneMut, CsMatrix};
|
||||||
use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter};
|
|
||||||
use crate::csr::CsrMatrix;
|
use crate::csr::CsrMatrix;
|
||||||
use crate::cs::{CsMatrix, CsLane, CsLaneMut, CsLaneIter, CsLaneIterMut};
|
use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter};
|
||||||
|
use crate::{SparseEntry, SparseEntryMut, SparseFormatError, SparseFormatErrorKind};
|
||||||
|
|
||||||
use std::slice::{IterMut, Iter};
|
|
||||||
use num_traits::{One};
|
|
||||||
use nalgebra::Scalar;
|
use nalgebra::Scalar;
|
||||||
|
use num_traits::One;
|
||||||
|
use std::slice::{Iter, IterMut};
|
||||||
|
|
||||||
/// A CSC representation of a sparse matrix.
|
/// A CSC representation of a sparse matrix.
|
||||||
///
|
///
|
||||||
|
@ -130,7 +130,7 @@ impl<T> CscMatrix<T> {
|
||||||
/// Create a zero CSC matrix with no explicitly stored entries.
|
/// Create a zero CSC matrix with no explicitly stored entries.
|
||||||
pub fn zeros(nrows: usize, ncols: usize) -> Self {
|
pub fn zeros(nrows: usize, ncols: usize) -> Self {
|
||||||
Self {
|
Self {
|
||||||
cs: CsMatrix::new(ncols, nrows)
|
cs: CsMatrix::new(ncols, nrows),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -196,8 +196,12 @@ impl<T> CscMatrix<T> {
|
||||||
values: Vec<T>,
|
values: Vec<T>,
|
||||||
) -> Result<Self, SparseFormatError> {
|
) -> Result<Self, SparseFormatError> {
|
||||||
let pattern = SparsityPattern::try_from_offsets_and_indices(
|
let pattern = SparsityPattern::try_from_offsets_and_indices(
|
||||||
num_cols, num_rows, col_offsets, row_indices)
|
num_cols,
|
||||||
.map_err(pattern_format_error_to_csc_error)?;
|
num_rows,
|
||||||
|
col_offsets,
|
||||||
|
row_indices,
|
||||||
|
)
|
||||||
|
.map_err(pattern_format_error_to_csc_error)?;
|
||||||
Self::try_from_pattern_and_values(pattern, values)
|
Self::try_from_pattern_and_values(pattern, values)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -205,16 +209,19 @@ impl<T> CscMatrix<T> {
|
||||||
///
|
///
|
||||||
/// Returns an error if the number of values does not match the number of minor indices
|
/// Returns an error if the number of values does not match the number of minor indices
|
||||||
/// in the pattern.
|
/// in the pattern.
|
||||||
pub fn try_from_pattern_and_values(pattern: SparsityPattern, values: Vec<T>)
|
pub fn try_from_pattern_and_values(
|
||||||
-> Result<Self, SparseFormatError> {
|
pattern: SparsityPattern,
|
||||||
|
values: Vec<T>,
|
||||||
|
) -> Result<Self, SparseFormatError> {
|
||||||
if pattern.nnz() == values.len() {
|
if pattern.nnz() == values.len() {
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
cs: CsMatrix::from_pattern_and_values(pattern, values)
|
cs: CsMatrix::from_pattern_and_values(pattern, values),
|
||||||
})
|
})
|
||||||
} else {
|
} else {
|
||||||
Err(SparseFormatError::from_kind_and_msg(
|
Err(SparseFormatError::from_kind_and_msg(
|
||||||
SparseFormatErrorKind::InvalidStructure,
|
SparseFormatErrorKind::InvalidStructure,
|
||||||
"Number of values and row indices must be the same"))
|
"Number of values and row indices must be the same",
|
||||||
|
))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -239,7 +246,7 @@ impl<T> CscMatrix<T> {
|
||||||
pub fn triplet_iter(&self) -> CscTripletIter<T> {
|
pub fn triplet_iter(&self) -> CscTripletIter<T> {
|
||||||
CscTripletIter {
|
CscTripletIter {
|
||||||
pattern_iter: self.pattern().entries(),
|
pattern_iter: self.pattern().entries(),
|
||||||
values_iter: self.values().iter()
|
values_iter: self.values().iter(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -270,7 +277,7 @@ impl<T> CscMatrix<T> {
|
||||||
let (pattern, values) = self.cs.pattern_and_values_mut();
|
let (pattern, values) = self.cs.pattern_and_values_mut();
|
||||||
CscTripletIterMut {
|
CscTripletIterMut {
|
||||||
pattern_iter: pattern.entries(),
|
pattern_iter: pattern.entries(),
|
||||||
values_mut_iter: values.iter_mut()
|
values_mut_iter: values.iter_mut(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -281,8 +288,7 @@ impl<T> CscMatrix<T> {
|
||||||
/// Panics if column index is out of bounds.
|
/// Panics if column index is out of bounds.
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn col(&self, index: usize) -> CscCol<T> {
|
pub fn col(&self, index: usize) -> CscCol<T> {
|
||||||
self.get_col(index)
|
self.get_col(index).expect("Row index must be in bounds")
|
||||||
.expect("Row index must be in bounds")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Mutable column access for the given column index.
|
/// Mutable column access for the given column index.
|
||||||
|
@ -299,23 +305,19 @@ impl<T> CscMatrix<T> {
|
||||||
/// Return the column at the given column index, or `None` if out of bounds.
|
/// Return the column at the given column index, or `None` if out of bounds.
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn get_col(&self, index: usize) -> Option<CscCol<T>> {
|
pub fn get_col(&self, index: usize) -> Option<CscCol<T>> {
|
||||||
self.cs
|
self.cs.get_lane(index).map(|lane| CscCol { lane })
|
||||||
.get_lane(index)
|
|
||||||
.map(|lane| CscCol { lane })
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Mutable column access for the given column index, or `None` if out of bounds.
|
/// Mutable column access for the given column index, or `None` if out of bounds.
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn get_col_mut(&mut self, index: usize) -> Option<CscColMut<T>> {
|
pub fn get_col_mut(&mut self, index: usize) -> Option<CscColMut<T>> {
|
||||||
self.cs
|
self.cs.get_lane_mut(index).map(|lane| CscColMut { lane })
|
||||||
.get_lane_mut(index)
|
|
||||||
.map(|lane| CscColMut { lane })
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// An iterator over columns in the matrix.
|
/// An iterator over columns in the matrix.
|
||||||
pub fn col_iter(&self) -> CscColIter<T> {
|
pub fn col_iter(&self) -> CscColIter<T> {
|
||||||
CscColIter {
|
CscColIter {
|
||||||
lane_iter: CsLaneIter::new(self.pattern(), self.values())
|
lane_iter: CsLaneIter::new(self.pattern(), self.values()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -323,7 +325,7 @@ impl<T> CscMatrix<T> {
|
||||||
pub fn col_iter_mut(&mut self) -> CscColIterMut<T> {
|
pub fn col_iter_mut(&mut self) -> CscColIterMut<T> {
|
||||||
let (pattern, values) = self.cs.pattern_and_values_mut();
|
let (pattern, values) = self.cs.pattern_and_values_mut();
|
||||||
CscColIterMut {
|
CscColIterMut {
|
||||||
lane_iter: CsLaneIterMut::new(pattern, values)
|
lane_iter: CsLaneIterMut::new(pattern, values),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -397,8 +399,11 @@ impl<T> CscMatrix<T> {
|
||||||
///
|
///
|
||||||
/// Each call to this function incurs the cost of a binary search among the explicitly
|
/// Each call to this function incurs the cost of a binary search among the explicitly
|
||||||
/// stored row entries for the given column.
|
/// stored row entries for the given column.
|
||||||
pub fn get_entry_mut(&mut self, row_index: usize, col_index: usize)
|
pub fn get_entry_mut(
|
||||||
-> Option<SparseEntryMut<T>> {
|
&mut self,
|
||||||
|
row_index: usize,
|
||||||
|
col_index: usize,
|
||||||
|
) -> Option<SparseEntryMut<T>> {
|
||||||
self.cs.get_entry_mut(col_index, row_index)
|
self.cs.get_entry_mut(col_index, row_index)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -444,11 +449,15 @@ impl<T> CscMatrix<T> {
|
||||||
pub fn filter<P>(&self, predicate: P) -> Self
|
pub fn filter<P>(&self, predicate: P) -> Self
|
||||||
where
|
where
|
||||||
T: Clone,
|
T: Clone,
|
||||||
P: Fn(usize, usize, &T) -> bool
|
P: Fn(usize, usize, &T) -> bool,
|
||||||
{
|
{
|
||||||
// Note: Predicate uses (row, col, value), so we have to switch around since
|
// Note: Predicate uses (row, col, value), so we have to switch around since
|
||||||
// cs uses (major, minor, value)
|
// cs uses (major, minor, value)
|
||||||
Self { cs: self.cs.filter(|col_idx, row_idx, v| predicate(row_idx, col_idx, v)) }
|
Self {
|
||||||
|
cs: self
|
||||||
|
.cs
|
||||||
|
.filter(|col_idx, row_idx, v| predicate(row_idx, col_idx, v)),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns a new matrix representing the upper triangular part of this matrix.
|
/// Returns a new matrix representing the upper triangular part of this matrix.
|
||||||
|
@ -456,7 +465,7 @@ impl<T> CscMatrix<T> {
|
||||||
/// The result includes the diagonal of the matrix.
|
/// The result includes the diagonal of the matrix.
|
||||||
pub fn upper_triangle(&self) -> Self
|
pub fn upper_triangle(&self) -> Self
|
||||||
where
|
where
|
||||||
T: Clone
|
T: Clone,
|
||||||
{
|
{
|
||||||
self.filter(|i, j, _| i <= j)
|
self.filter(|i, j, _| i <= j)
|
||||||
}
|
}
|
||||||
|
@ -466,7 +475,7 @@ impl<T> CscMatrix<T> {
|
||||||
/// The result includes the diagonal of the matrix.
|
/// The result includes the diagonal of the matrix.
|
||||||
pub fn lower_triangle(&self) -> Self
|
pub fn lower_triangle(&self) -> Self
|
||||||
where
|
where
|
||||||
T: Clone
|
T: Clone,
|
||||||
{
|
{
|
||||||
self.filter(|i, j, _| i >= j)
|
self.filter(|i, j, _| i >= j)
|
||||||
}
|
}
|
||||||
|
@ -474,15 +483,17 @@ impl<T> CscMatrix<T> {
|
||||||
/// Returns the diagonal of the matrix as a sparse matrix.
|
/// Returns the diagonal of the matrix as a sparse matrix.
|
||||||
pub fn diagonal_as_csc(&self) -> Self
|
pub fn diagonal_as_csc(&self) -> Self
|
||||||
where
|
where
|
||||||
T: Clone
|
T: Clone,
|
||||||
{
|
{
|
||||||
Self { cs: self.cs.diagonal_as_matrix() }
|
Self {
|
||||||
|
cs: self.cs.diagonal_as_matrix(),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T> CscMatrix<T>
|
impl<T> CscMatrix<T>
|
||||||
where
|
where
|
||||||
T: Scalar
|
T: Scalar,
|
||||||
{
|
{
|
||||||
/// Compute the transpose of the matrix.
|
/// Compute the transpose of the matrix.
|
||||||
pub fn transpose(&self) -> CscMatrix<T> {
|
pub fn transpose(&self) -> CscMatrix<T> {
|
||||||
|
@ -495,7 +506,7 @@ impl<T: Scalar + One> CscMatrix<T> {
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn identity(n: usize) -> Self {
|
pub fn identity(n: usize) -> Self {
|
||||||
Self {
|
Self {
|
||||||
cs: CsMatrix::identity(n)
|
cs: CsMatrix::identity(n),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -505,30 +516,34 @@ impl<T: Scalar + One> CscMatrix<T> {
|
||||||
/// This ensures that the terminology is consistent: we are talking about rows and columns,
|
/// This ensures that the terminology is consistent: we are talking about rows and columns,
|
||||||
/// not lanes, major and minor dimensions.
|
/// not lanes, major and minor dimensions.
|
||||||
fn pattern_format_error_to_csc_error(err: SparsityPatternFormatError) -> SparseFormatError {
|
fn pattern_format_error_to_csc_error(err: SparsityPatternFormatError) -> SparseFormatError {
|
||||||
use SparsityPatternFormatError::*;
|
|
||||||
use SparsityPatternFormatError::DuplicateEntry as PatternDuplicateEntry;
|
|
||||||
use SparseFormatError as E;
|
use SparseFormatError as E;
|
||||||
use SparseFormatErrorKind as K;
|
use SparseFormatErrorKind as K;
|
||||||
|
use SparsityPatternFormatError::DuplicateEntry as PatternDuplicateEntry;
|
||||||
|
use SparsityPatternFormatError::*;
|
||||||
|
|
||||||
match err {
|
match err {
|
||||||
InvalidOffsetArrayLength => E::from_kind_and_msg(
|
InvalidOffsetArrayLength => E::from_kind_and_msg(
|
||||||
K::InvalidStructure,
|
K::InvalidStructure,
|
||||||
"Length of col offset array is not equal to ncols + 1."),
|
"Length of col offset array is not equal to ncols + 1.",
|
||||||
|
),
|
||||||
InvalidOffsetFirstLast => E::from_kind_and_msg(
|
InvalidOffsetFirstLast => E::from_kind_and_msg(
|
||||||
K::InvalidStructure,
|
K::InvalidStructure,
|
||||||
"First or last col offset is inconsistent with format specification."),
|
"First or last col offset is inconsistent with format specification.",
|
||||||
|
),
|
||||||
NonmonotonicOffsets => E::from_kind_and_msg(
|
NonmonotonicOffsets => E::from_kind_and_msg(
|
||||||
K::InvalidStructure,
|
K::InvalidStructure,
|
||||||
"Col offsets are not monotonically increasing."),
|
"Col offsets are not monotonically increasing.",
|
||||||
|
),
|
||||||
NonmonotonicMinorIndices => E::from_kind_and_msg(
|
NonmonotonicMinorIndices => E::from_kind_and_msg(
|
||||||
K::InvalidStructure,
|
K::InvalidStructure,
|
||||||
"Row indices are not monotonically increasing (sorted) within each column."),
|
"Row indices are not monotonically increasing (sorted) within each column.",
|
||||||
MinorIndexOutOfBounds => E::from_kind_and_msg(
|
),
|
||||||
K::IndexOutOfBounds,
|
MinorIndexOutOfBounds => {
|
||||||
"Row indices are out of bounds."),
|
E::from_kind_and_msg(K::IndexOutOfBounds, "Row indices are out of bounds.")
|
||||||
PatternDuplicateEntry => E::from_kind_and_msg(
|
}
|
||||||
K::DuplicateEntry,
|
PatternDuplicateEntry => {
|
||||||
"Matrix data contains duplicate entries."),
|
E::from_kind_and_msg(K::DuplicateEntry, "Matrix data contains duplicate entries.")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -536,7 +551,7 @@ fn pattern_format_error_to_csc_error(err: SparsityPatternFormatError) -> SparseF
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct CscTripletIter<'a, T> {
|
pub struct CscTripletIter<'a, T> {
|
||||||
pattern_iter: SparsityPatternIter<'a>,
|
pattern_iter: SparsityPatternIter<'a>,
|
||||||
values_iter: Iter<'a, T>
|
values_iter: Iter<'a, T>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, T: Clone> CscTripletIter<'a, T> {
|
impl<'a, T: Clone> CscTripletIter<'a, T> {
|
||||||
|
@ -545,7 +560,7 @@ impl<'a, T: Clone> CscTripletIter<'a, T> {
|
||||||
/// The triplet iterator returns references to the values. This method adapts the iterator
|
/// The triplet iterator returns references to the values. This method adapts the iterator
|
||||||
/// so that the values are cloned.
|
/// so that the values are cloned.
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn cloned_values(self) -> impl 'a + Iterator<Item=(usize, usize, T)> {
|
pub fn cloned_values(self) -> impl 'a + Iterator<Item = (usize, usize, T)> {
|
||||||
self.map(|(i, j, v)| (i, j, v.clone()))
|
self.map(|(i, j, v)| (i, j, v.clone()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -559,7 +574,7 @@ impl<'a, T> Iterator for CscTripletIter<'a, T> {
|
||||||
|
|
||||||
match (next_entry, next_value) {
|
match (next_entry, next_value) {
|
||||||
(Some((i, j)), Some(v)) => Some((j, i, v)),
|
(Some((i, j)), Some(v)) => Some((j, i, v)),
|
||||||
_ => None
|
_ => None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -568,7 +583,7 @@ impl<'a, T> Iterator for CscTripletIter<'a, T> {
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct CscTripletIterMut<'a, T> {
|
pub struct CscTripletIterMut<'a, T> {
|
||||||
pattern_iter: SparsityPatternIter<'a>,
|
pattern_iter: SparsityPatternIter<'a>,
|
||||||
values_mut_iter: IterMut<'a, T>
|
values_mut_iter: IterMut<'a, T>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, T> Iterator for CscTripletIterMut<'a, T> {
|
impl<'a, T> Iterator for CscTripletIterMut<'a, T> {
|
||||||
|
@ -581,7 +596,7 @@ impl<'a, T> Iterator for CscTripletIterMut<'a, T> {
|
||||||
|
|
||||||
match (next_entry, next_value) {
|
match (next_entry, next_value) {
|
||||||
(Some((i, j)), Some(v)) => Some((j, i, v)),
|
(Some((i, j)), Some(v)) => Some((j, i, v)),
|
||||||
_ => None
|
_ => None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -589,7 +604,7 @@ impl<'a, T> Iterator for CscTripletIterMut<'a, T> {
|
||||||
/// An immutable representation of a column in a CSC matrix.
|
/// An immutable representation of a column in a CSC matrix.
|
||||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
pub struct CscCol<'a, T> {
|
pub struct CscCol<'a, T> {
|
||||||
lane: CsLane<'a, T>
|
lane: CsLane<'a, T>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// A mutable representation of a column in a CSC matrix.
|
/// A mutable representation of a column in a CSC matrix.
|
||||||
|
@ -598,7 +613,7 @@ pub struct CscCol<'a, T> {
|
||||||
/// to the column cannot be modified.
|
/// to the column cannot be modified.
|
||||||
#[derive(Debug, PartialEq, Eq)]
|
#[derive(Debug, PartialEq, Eq)]
|
||||||
pub struct CscColMut<'a, T> {
|
pub struct CscColMut<'a, T> {
|
||||||
lane: CsLaneMut<'a, T>
|
lane: CsLaneMut<'a, T>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Implement the methods common to both CscCol and CscColMut
|
/// Implement the methods common to both CscCol and CscColMut
|
||||||
|
@ -637,7 +652,7 @@ macro_rules! impl_csc_col_common_methods {
|
||||||
self.lane.get_entry(global_row_index)
|
self.lane.get_entry(global_row_index)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
impl_csc_col_common_methods!(CscCol<'a, T>);
|
impl_csc_col_common_methods!(CscCol<'a, T>);
|
||||||
|
@ -666,33 +681,29 @@ impl<'a, T> CscColMut<'a, T> {
|
||||||
|
|
||||||
/// Column iterator for [CscMatrix](struct.CscMatrix.html).
|
/// Column iterator for [CscMatrix](struct.CscMatrix.html).
|
||||||
pub struct CscColIter<'a, T> {
|
pub struct CscColIter<'a, T> {
|
||||||
lane_iter: CsLaneIter<'a, T>
|
lane_iter: CsLaneIter<'a, T>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, T> Iterator for CscColIter<'a, T> {
|
impl<'a, T> Iterator for CscColIter<'a, T> {
|
||||||
type Item = CscCol<'a, T>;
|
type Item = CscCol<'a, T>;
|
||||||
|
|
||||||
fn next(&mut self) -> Option<Self::Item> {
|
fn next(&mut self) -> Option<Self::Item> {
|
||||||
self.lane_iter
|
self.lane_iter.next().map(|lane| CscCol { lane })
|
||||||
.next()
|
|
||||||
.map(|lane| CscCol { lane })
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Mutable column iterator for [CscMatrix](struct.CscMatrix.html).
|
/// Mutable column iterator for [CscMatrix](struct.CscMatrix.html).
|
||||||
pub struct CscColIterMut<'a, T> {
|
pub struct CscColIterMut<'a, T> {
|
||||||
lane_iter: CsLaneIterMut<'a, T>
|
lane_iter: CsLaneIterMut<'a, T>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, T> Iterator for CscColIterMut<'a, T>
|
impl<'a, T> Iterator for CscColIterMut<'a, T>
|
||||||
where
|
where
|
||||||
T: 'a
|
T: 'a,
|
||||||
{
|
{
|
||||||
type Item = CscColMut<'a, T>;
|
type Item = CscColMut<'a, T>;
|
||||||
|
|
||||||
fn next(&mut self) -> Option<Self::Item> {
|
fn next(&mut self) -> Option<Self::Item> {
|
||||||
self.lane_iter
|
self.lane_iter.next().map(|lane| CscColMut { lane })
|
||||||
.next()
|
|
||||||
.map(|lane| CscColMut { lane })
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,15 +2,15 @@
|
||||||
//!
|
//!
|
||||||
//! This is the module-level documentation. See [`CsrMatrix`] for the main documentation of the
|
//! This is the module-level documentation. See [`CsrMatrix`] for the main documentation of the
|
||||||
//! CSC implementation.
|
//! CSC implementation.
|
||||||
use crate::{SparseFormatError, SparseFormatErrorKind, SparseEntry, SparseEntryMut};
|
use crate::cs::{CsLane, CsLaneIter, CsLaneIterMut, CsLaneMut, CsMatrix};
|
||||||
use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter};
|
|
||||||
use crate::csc::CscMatrix;
|
use crate::csc::CscMatrix;
|
||||||
use crate::cs::{CsMatrix, CsLaneIterMut, CsLaneIter, CsLane, CsLaneMut};
|
use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter};
|
||||||
|
use crate::{SparseEntry, SparseEntryMut, SparseFormatError, SparseFormatErrorKind};
|
||||||
|
|
||||||
use nalgebra::Scalar;
|
use nalgebra::Scalar;
|
||||||
use num_traits::{One};
|
use num_traits::One;
|
||||||
|
|
||||||
use std::slice::{IterMut, Iter};
|
use std::slice::{Iter, IterMut};
|
||||||
|
|
||||||
/// A CSR representation of a sparse matrix.
|
/// A CSR representation of a sparse matrix.
|
||||||
///
|
///
|
||||||
|
@ -130,7 +130,7 @@ impl<T> CsrMatrix<T> {
|
||||||
/// Create a zero CSR matrix with no explicitly stored entries.
|
/// Create a zero CSR matrix with no explicitly stored entries.
|
||||||
pub fn zeros(nrows: usize, ncols: usize) -> Self {
|
pub fn zeros(nrows: usize, ncols: usize) -> Self {
|
||||||
Self {
|
Self {
|
||||||
cs: CsMatrix::new(nrows, ncols)
|
cs: CsMatrix::new(nrows, ncols),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -198,8 +198,12 @@ impl<T> CsrMatrix<T> {
|
||||||
values: Vec<T>,
|
values: Vec<T>,
|
||||||
) -> Result<Self, SparseFormatError> {
|
) -> Result<Self, SparseFormatError> {
|
||||||
let pattern = SparsityPattern::try_from_offsets_and_indices(
|
let pattern = SparsityPattern::try_from_offsets_and_indices(
|
||||||
num_rows, num_cols, row_offsets, col_indices)
|
num_rows,
|
||||||
.map_err(pattern_format_error_to_csr_error)?;
|
num_cols,
|
||||||
|
row_offsets,
|
||||||
|
col_indices,
|
||||||
|
)
|
||||||
|
.map_err(pattern_format_error_to_csr_error)?;
|
||||||
Self::try_from_pattern_and_values(pattern, values)
|
Self::try_from_pattern_and_values(pattern, values)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -207,16 +211,19 @@ impl<T> CsrMatrix<T> {
|
||||||
///
|
///
|
||||||
/// Returns an error if the number of values does not match the number of minor indices
|
/// Returns an error if the number of values does not match the number of minor indices
|
||||||
/// in the pattern.
|
/// in the pattern.
|
||||||
pub fn try_from_pattern_and_values(pattern: SparsityPattern, values: Vec<T>)
|
pub fn try_from_pattern_and_values(
|
||||||
-> Result<Self, SparseFormatError> {
|
pattern: SparsityPattern,
|
||||||
|
values: Vec<T>,
|
||||||
|
) -> Result<Self, SparseFormatError> {
|
||||||
if pattern.nnz() == values.len() {
|
if pattern.nnz() == values.len() {
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
cs: CsMatrix::from_pattern_and_values(pattern, values)
|
cs: CsMatrix::from_pattern_and_values(pattern, values),
|
||||||
})
|
})
|
||||||
} else {
|
} else {
|
||||||
Err(SparseFormatError::from_kind_and_msg(
|
Err(SparseFormatError::from_kind_and_msg(
|
||||||
SparseFormatErrorKind::InvalidStructure,
|
SparseFormatErrorKind::InvalidStructure,
|
||||||
"Number of values and column indices must be the same"))
|
"Number of values and column indices must be the same",
|
||||||
|
))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -241,7 +248,7 @@ impl<T> CsrMatrix<T> {
|
||||||
pub fn triplet_iter(&self) -> CsrTripletIter<T> {
|
pub fn triplet_iter(&self) -> CsrTripletIter<T> {
|
||||||
CsrTripletIter {
|
CsrTripletIter {
|
||||||
pattern_iter: self.pattern().entries(),
|
pattern_iter: self.pattern().entries(),
|
||||||
values_iter: self.values().iter()
|
values_iter: self.values().iter(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -272,7 +279,7 @@ impl<T> CsrMatrix<T> {
|
||||||
let (pattern, values) = self.cs.pattern_and_values_mut();
|
let (pattern, values) = self.cs.pattern_and_values_mut();
|
||||||
CsrTripletIterMut {
|
CsrTripletIterMut {
|
||||||
pattern_iter: pattern.entries(),
|
pattern_iter: pattern.entries(),
|
||||||
values_mut_iter: values.iter_mut()
|
values_mut_iter: values.iter_mut(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -283,8 +290,7 @@ impl<T> CsrMatrix<T> {
|
||||||
/// Panics if row index is out of bounds.
|
/// Panics if row index is out of bounds.
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn row(&self, index: usize) -> CsrRow<T> {
|
pub fn row(&self, index: usize) -> CsrRow<T> {
|
||||||
self.get_row(index)
|
self.get_row(index).expect("Row index must be in bounds")
|
||||||
.expect("Row index must be in bounds")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Mutable row access for the given row index.
|
/// Mutable row access for the given row index.
|
||||||
|
@ -301,23 +307,19 @@ impl<T> CsrMatrix<T> {
|
||||||
/// Return the row at the given row index, or `None` if out of bounds.
|
/// Return the row at the given row index, or `None` if out of bounds.
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn get_row(&self, index: usize) -> Option<CsrRow<T>> {
|
pub fn get_row(&self, index: usize) -> Option<CsrRow<T>> {
|
||||||
self.cs
|
self.cs.get_lane(index).map(|lane| CsrRow { lane })
|
||||||
.get_lane(index)
|
|
||||||
.map(|lane| CsrRow { lane })
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Mutable row access for the given row index, or `None` if out of bounds.
|
/// Mutable row access for the given row index, or `None` if out of bounds.
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn get_row_mut(&mut self, index: usize) -> Option<CsrRowMut<T>> {
|
pub fn get_row_mut(&mut self, index: usize) -> Option<CsrRowMut<T>> {
|
||||||
self.cs
|
self.cs.get_lane_mut(index).map(|lane| CsrRowMut { lane })
|
||||||
.get_lane_mut(index)
|
|
||||||
.map(|lane| CsrRowMut { lane })
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// An iterator over rows in the matrix.
|
/// An iterator over rows in the matrix.
|
||||||
pub fn row_iter(&self) -> CsrRowIter<T> {
|
pub fn row_iter(&self) -> CsrRowIter<T> {
|
||||||
CsrRowIter {
|
CsrRowIter {
|
||||||
lane_iter: CsLaneIter::new(self.pattern(), self.values())
|
lane_iter: CsLaneIter::new(self.pattern(), self.values()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -399,8 +401,11 @@ impl<T> CsrMatrix<T> {
|
||||||
///
|
///
|
||||||
/// Each call to this function incurs the cost of a binary search among the explicitly
|
/// Each call to this function incurs the cost of a binary search among the explicitly
|
||||||
/// stored column entries for the given row.
|
/// stored column entries for the given row.
|
||||||
pub fn get_entry_mut(&mut self, row_index: usize, col_index: usize)
|
pub fn get_entry_mut(
|
||||||
-> Option<SparseEntryMut<T>> {
|
&mut self,
|
||||||
|
row_index: usize,
|
||||||
|
col_index: usize,
|
||||||
|
) -> Option<SparseEntryMut<T>> {
|
||||||
self.cs.get_entry_mut(row_index, col_index)
|
self.cs.get_entry_mut(row_index, col_index)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -444,19 +449,23 @@ impl<T> CsrMatrix<T> {
|
||||||
/// Creates a sparse matrix that contains only the explicit entries decided by the
|
/// Creates a sparse matrix that contains only the explicit entries decided by the
|
||||||
/// given predicate.
|
/// given predicate.
|
||||||
pub fn filter<P>(&self, predicate: P) -> Self
|
pub fn filter<P>(&self, predicate: P) -> Self
|
||||||
where
|
where
|
||||||
T: Clone,
|
T: Clone,
|
||||||
P: Fn(usize, usize, &T) -> bool
|
P: Fn(usize, usize, &T) -> bool,
|
||||||
{
|
{
|
||||||
Self { cs: self.cs.filter(|row_idx, col_idx, v| predicate(row_idx, col_idx, v)) }
|
Self {
|
||||||
|
cs: self
|
||||||
|
.cs
|
||||||
|
.filter(|row_idx, col_idx, v| predicate(row_idx, col_idx, v)),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns a new matrix representing the upper triangular part of this matrix.
|
/// Returns a new matrix representing the upper triangular part of this matrix.
|
||||||
///
|
///
|
||||||
/// The result includes the diagonal of the matrix.
|
/// The result includes the diagonal of the matrix.
|
||||||
pub fn upper_triangle(&self) -> Self
|
pub fn upper_triangle(&self) -> Self
|
||||||
where
|
where
|
||||||
T: Clone
|
T: Clone,
|
||||||
{
|
{
|
||||||
self.filter(|i, j, _| i <= j)
|
self.filter(|i, j, _| i <= j)
|
||||||
}
|
}
|
||||||
|
@ -465,24 +474,26 @@ impl<T> CsrMatrix<T> {
|
||||||
///
|
///
|
||||||
/// The result includes the diagonal of the matrix.
|
/// The result includes the diagonal of the matrix.
|
||||||
pub fn lower_triangle(&self) -> Self
|
pub fn lower_triangle(&self) -> Self
|
||||||
where
|
where
|
||||||
T: Clone
|
T: Clone,
|
||||||
{
|
{
|
||||||
self.filter(|i, j, _| i >= j)
|
self.filter(|i, j, _| i >= j)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns the diagonal of the matrix as a sparse matrix.
|
/// Returns the diagonal of the matrix as a sparse matrix.
|
||||||
pub fn diagonal_as_csr(&self) -> Self
|
pub fn diagonal_as_csr(&self) -> Self
|
||||||
where
|
where
|
||||||
T: Clone
|
T: Clone,
|
||||||
{
|
{
|
||||||
Self { cs: self.cs.diagonal_as_matrix() }
|
Self {
|
||||||
|
cs: self.cs.diagonal_as_matrix(),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T> CsrMatrix<T>
|
impl<T> CsrMatrix<T>
|
||||||
where
|
where
|
||||||
T: Scalar
|
T: Scalar,
|
||||||
{
|
{
|
||||||
/// Compute the transpose of the matrix.
|
/// Compute the transpose of the matrix.
|
||||||
pub fn transpose(&self) -> CsrMatrix<T> {
|
pub fn transpose(&self) -> CsrMatrix<T> {
|
||||||
|
@ -495,7 +506,7 @@ impl<T: Scalar + One> CsrMatrix<T> {
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn identity(n: usize) -> Self {
|
pub fn identity(n: usize) -> Self {
|
||||||
Self {
|
Self {
|
||||||
cs: CsMatrix::identity(n)
|
cs: CsMatrix::identity(n),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -505,30 +516,34 @@ impl<T: Scalar + One> CsrMatrix<T> {
|
||||||
/// This ensures that the terminology is consistent: we are talking about rows and columns,
|
/// This ensures that the terminology is consistent: we are talking about rows and columns,
|
||||||
/// not lanes, major and minor dimensions.
|
/// not lanes, major and minor dimensions.
|
||||||
fn pattern_format_error_to_csr_error(err: SparsityPatternFormatError) -> SparseFormatError {
|
fn pattern_format_error_to_csr_error(err: SparsityPatternFormatError) -> SparseFormatError {
|
||||||
use SparsityPatternFormatError::*;
|
|
||||||
use SparsityPatternFormatError::DuplicateEntry as PatternDuplicateEntry;
|
|
||||||
use SparseFormatError as E;
|
use SparseFormatError as E;
|
||||||
use SparseFormatErrorKind as K;
|
use SparseFormatErrorKind as K;
|
||||||
|
use SparsityPatternFormatError::DuplicateEntry as PatternDuplicateEntry;
|
||||||
|
use SparsityPatternFormatError::*;
|
||||||
|
|
||||||
match err {
|
match err {
|
||||||
InvalidOffsetArrayLength => E::from_kind_and_msg(
|
InvalidOffsetArrayLength => E::from_kind_and_msg(
|
||||||
K::InvalidStructure,
|
K::InvalidStructure,
|
||||||
"Length of row offset array is not equal to nrows + 1."),
|
"Length of row offset array is not equal to nrows + 1.",
|
||||||
|
),
|
||||||
InvalidOffsetFirstLast => E::from_kind_and_msg(
|
InvalidOffsetFirstLast => E::from_kind_and_msg(
|
||||||
K::InvalidStructure,
|
K::InvalidStructure,
|
||||||
"First or last row offset is inconsistent with format specification."),
|
"First or last row offset is inconsistent with format specification.",
|
||||||
|
),
|
||||||
NonmonotonicOffsets => E::from_kind_and_msg(
|
NonmonotonicOffsets => E::from_kind_and_msg(
|
||||||
K::InvalidStructure,
|
K::InvalidStructure,
|
||||||
"Row offsets are not monotonically increasing."),
|
"Row offsets are not monotonically increasing.",
|
||||||
|
),
|
||||||
NonmonotonicMinorIndices => E::from_kind_and_msg(
|
NonmonotonicMinorIndices => E::from_kind_and_msg(
|
||||||
K::InvalidStructure,
|
K::InvalidStructure,
|
||||||
"Column indices are not monotonically increasing (sorted) within each row."),
|
"Column indices are not monotonically increasing (sorted) within each row.",
|
||||||
MinorIndexOutOfBounds => E::from_kind_and_msg(
|
),
|
||||||
K::IndexOutOfBounds,
|
MinorIndexOutOfBounds => {
|
||||||
"Column indices are out of bounds."),
|
E::from_kind_and_msg(K::IndexOutOfBounds, "Column indices are out of bounds.")
|
||||||
PatternDuplicateEntry => E::from_kind_and_msg(
|
}
|
||||||
K::DuplicateEntry,
|
PatternDuplicateEntry => {
|
||||||
"Matrix data contains duplicate entries."),
|
E::from_kind_and_msg(K::DuplicateEntry, "Matrix data contains duplicate entries.")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -536,7 +551,7 @@ fn pattern_format_error_to_csr_error(err: SparsityPatternFormatError) -> SparseF
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct CsrTripletIter<'a, T> {
|
pub struct CsrTripletIter<'a, T> {
|
||||||
pattern_iter: SparsityPatternIter<'a>,
|
pattern_iter: SparsityPatternIter<'a>,
|
||||||
values_iter: Iter<'a, T>
|
values_iter: Iter<'a, T>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, T: Clone> CsrTripletIter<'a, T> {
|
impl<'a, T: Clone> CsrTripletIter<'a, T> {
|
||||||
|
@ -545,7 +560,7 @@ impl<'a, T: Clone> CsrTripletIter<'a, T> {
|
||||||
/// The triplet iterator returns references to the values. This method adapts the iterator
|
/// The triplet iterator returns references to the values. This method adapts the iterator
|
||||||
/// so that the values are cloned.
|
/// so that the values are cloned.
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn cloned_values(self) -> impl 'a + Iterator<Item=(usize, usize, T)> {
|
pub fn cloned_values(self) -> impl 'a + Iterator<Item = (usize, usize, T)> {
|
||||||
self.map(|(i, j, v)| (i, j, v.clone()))
|
self.map(|(i, j, v)| (i, j, v.clone()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -559,7 +574,7 @@ impl<'a, T> Iterator for CsrTripletIter<'a, T> {
|
||||||
|
|
||||||
match (next_entry, next_value) {
|
match (next_entry, next_value) {
|
||||||
(Some((i, j)), Some(v)) => Some((i, j, v)),
|
(Some((i, j)), Some(v)) => Some((i, j, v)),
|
||||||
_ => None
|
_ => None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -568,7 +583,7 @@ impl<'a, T> Iterator for CsrTripletIter<'a, T> {
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct CsrTripletIterMut<'a, T> {
|
pub struct CsrTripletIterMut<'a, T> {
|
||||||
pattern_iter: SparsityPatternIter<'a>,
|
pattern_iter: SparsityPatternIter<'a>,
|
||||||
values_mut_iter: IterMut<'a, T>
|
values_mut_iter: IterMut<'a, T>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, T> Iterator for CsrTripletIterMut<'a, T> {
|
impl<'a, T> Iterator for CsrTripletIterMut<'a, T> {
|
||||||
|
@ -581,7 +596,7 @@ impl<'a, T> Iterator for CsrTripletIterMut<'a, T> {
|
||||||
|
|
||||||
match (next_entry, next_value) {
|
match (next_entry, next_value) {
|
||||||
(Some((i, j)), Some(v)) => Some((i, j, v)),
|
(Some((i, j)), Some(v)) => Some((i, j, v)),
|
||||||
_ => None
|
_ => None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -589,7 +604,7 @@ impl<'a, T> Iterator for CsrTripletIterMut<'a, T> {
|
||||||
/// An immutable representation of a row in a CSR matrix.
|
/// An immutable representation of a row in a CSR matrix.
|
||||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
pub struct CsrRow<'a, T> {
|
pub struct CsrRow<'a, T> {
|
||||||
lane: CsLane<'a, T>
|
lane: CsLane<'a, T>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// A mutable representation of a row in a CSR matrix.
|
/// A mutable representation of a row in a CSR matrix.
|
||||||
|
@ -598,7 +613,7 @@ pub struct CsrRow<'a, T> {
|
||||||
/// to the row cannot be modified.
|
/// to the row cannot be modified.
|
||||||
#[derive(Debug, PartialEq, Eq)]
|
#[derive(Debug, PartialEq, Eq)]
|
||||||
pub struct CsrRowMut<'a, T> {
|
pub struct CsrRowMut<'a, T> {
|
||||||
lane: CsLaneMut<'a, T>
|
lane: CsLaneMut<'a, T>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Implement the methods common to both CsrRow and CsrRowMut
|
/// Implement the methods common to both CsrRow and CsrRowMut
|
||||||
|
@ -638,7 +653,7 @@ macro_rules! impl_csr_row_common_methods {
|
||||||
self.lane.get_entry(global_col_index)
|
self.lane.get_entry(global_col_index)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
impl_csr_row_common_methods!(CsrRow<'a, T>);
|
impl_csr_row_common_methods!(CsrRow<'a, T>);
|
||||||
|
@ -670,33 +685,29 @@ impl<'a, T> CsrRowMut<'a, T> {
|
||||||
|
|
||||||
/// Row iterator for [CsrMatrix](struct.CsrMatrix.html).
|
/// Row iterator for [CsrMatrix](struct.CsrMatrix.html).
|
||||||
pub struct CsrRowIter<'a, T> {
|
pub struct CsrRowIter<'a, T> {
|
||||||
lane_iter: CsLaneIter<'a, T>
|
lane_iter: CsLaneIter<'a, T>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, T> Iterator for CsrRowIter<'a, T> {
|
impl<'a, T> Iterator for CsrRowIter<'a, T> {
|
||||||
type Item = CsrRow<'a, T>;
|
type Item = CsrRow<'a, T>;
|
||||||
|
|
||||||
fn next(&mut self) -> Option<Self::Item> {
|
fn next(&mut self) -> Option<Self::Item> {
|
||||||
self.lane_iter
|
self.lane_iter.next().map(|lane| CsrRow { lane })
|
||||||
.next()
|
|
||||||
.map(|lane| CsrRow { lane })
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Mutable row iterator for [CsrMatrix](struct.CsrMatrix.html).
|
/// Mutable row iterator for [CsrMatrix](struct.CsrMatrix.html).
|
||||||
pub struct CsrRowIterMut<'a, T> {
|
pub struct CsrRowIterMut<'a, T> {
|
||||||
lane_iter: CsLaneIterMut<'a, T>
|
lane_iter: CsLaneIterMut<'a, T>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, T> Iterator for CsrRowIterMut<'a, T>
|
impl<'a, T> Iterator for CsrRowIterMut<'a, T>
|
||||||
where
|
where
|
||||||
T: 'a
|
T: 'a,
|
||||||
{
|
{
|
||||||
type Item = CsrRowMut<'a, T>;
|
type Item = CsrRowMut<'a, T>;
|
||||||
|
|
||||||
fn next(&mut self) -> Option<Self::Item> {
|
fn next(&mut self) -> Option<Self::Item> {
|
||||||
self.lane_iter
|
self.lane_iter.next().map(|lane| CsrRowMut { lane })
|
||||||
.next()
|
|
||||||
.map(|lane| CsrRowMut { lane })
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,10 +1,10 @@
|
||||||
use crate::pattern::SparsityPattern;
|
|
||||||
use crate::csc::CscMatrix;
|
use crate::csc::CscMatrix;
|
||||||
use core::{mem, iter};
|
|
||||||
use nalgebra::{Scalar, RealField, DMatrixSlice, DMatrixSliceMut, DMatrix};
|
|
||||||
use std::fmt::{Display, Formatter};
|
|
||||||
use crate::ops::serial::spsolve_csc_lower_triangular;
|
use crate::ops::serial::spsolve_csc_lower_triangular;
|
||||||
use crate::ops::Op;
|
use crate::ops::Op;
|
||||||
|
use crate::pattern::SparsityPattern;
|
||||||
|
use core::{iter, mem};
|
||||||
|
use nalgebra::{DMatrix, DMatrixSlice, DMatrixSliceMut, RealField, Scalar};
|
||||||
|
use std::fmt::{Display, Formatter};
|
||||||
|
|
||||||
/// A symbolic sparse Cholesky factorization of a CSC matrix.
|
/// A symbolic sparse Cholesky factorization of a CSC matrix.
|
||||||
///
|
///
|
||||||
|
@ -15,7 +15,7 @@ pub struct CscSymbolicCholesky {
|
||||||
m_pattern: SparsityPattern,
|
m_pattern: SparsityPattern,
|
||||||
l_pattern: SparsityPattern,
|
l_pattern: SparsityPattern,
|
||||||
// u in this context is L^T, so that M = L L^T
|
// u in this context is L^T, so that M = L L^T
|
||||||
u_pattern: SparsityPattern
|
u_pattern: SparsityPattern,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl CscSymbolicCholesky {
|
impl CscSymbolicCholesky {
|
||||||
|
@ -28,8 +28,11 @@ impl CscSymbolicCholesky {
|
||||||
///
|
///
|
||||||
/// Panics if the sparsity pattern is not square.
|
/// Panics if the sparsity pattern is not square.
|
||||||
pub fn factor(pattern: SparsityPattern) -> Self {
|
pub fn factor(pattern: SparsityPattern) -> Self {
|
||||||
assert_eq!(pattern.major_dim(), pattern.minor_dim(),
|
assert_eq!(
|
||||||
"Major and minor dimensions must be the same (square matrix).");
|
pattern.major_dim(),
|
||||||
|
pattern.minor_dim(),
|
||||||
|
"Major and minor dimensions must be the same (square matrix)."
|
||||||
|
);
|
||||||
let (l_pattern, u_pattern) = nonzero_pattern(&pattern);
|
let (l_pattern, u_pattern) = nonzero_pattern(&pattern);
|
||||||
Self {
|
Self {
|
||||||
m_pattern: pattern,
|
m_pattern: pattern,
|
||||||
|
@ -65,7 +68,7 @@ pub struct CscCholesky<T> {
|
||||||
l_factor: CscMatrix<T>,
|
l_factor: CscMatrix<T>,
|
||||||
u_pattern: SparsityPattern,
|
u_pattern: SparsityPattern,
|
||||||
work_x: Vec<T>,
|
work_x: Vec<T>,
|
||||||
work_c: Vec<usize>
|
work_c: Vec<usize>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, PartialEq, Eq, Clone)]
|
#[derive(Debug, PartialEq, Eq, Clone)]
|
||||||
|
@ -100,16 +103,20 @@ impl<T: RealField> CscCholesky<T> {
|
||||||
///
|
///
|
||||||
/// Panics if the number of values differ from the number of non-zeros of the sparsity pattern
|
/// Panics if the number of values differ from the number of non-zeros of the sparsity pattern
|
||||||
/// of the matrix that was symbolically factored.
|
/// of the matrix that was symbolically factored.
|
||||||
pub fn factor_numerical(symbolic: CscSymbolicCholesky, values: &[T])
|
pub fn factor_numerical(
|
||||||
-> Result<Self, CholeskyError>
|
symbolic: CscSymbolicCholesky,
|
||||||
{
|
values: &[T],
|
||||||
assert_eq!(symbolic.l_pattern.nnz(), symbolic.u_pattern.nnz(),
|
) -> Result<Self, CholeskyError> {
|
||||||
"u is just the transpose of l, so should have the same nnz");
|
assert_eq!(
|
||||||
|
symbolic.l_pattern.nnz(),
|
||||||
|
symbolic.u_pattern.nnz(),
|
||||||
|
"u is just the transpose of l, so should have the same nnz"
|
||||||
|
);
|
||||||
|
|
||||||
let l_nnz = symbolic.l_pattern.nnz();
|
let l_nnz = symbolic.l_pattern.nnz();
|
||||||
let l_values = vec![T::zero(); l_nnz];
|
let l_values = vec![T::zero(); l_nnz];
|
||||||
let l_factor = CscMatrix::try_from_pattern_and_values(symbolic.l_pattern, l_values)
|
let l_factor =
|
||||||
.unwrap();
|
CscMatrix::try_from_pattern_and_values(symbolic.l_pattern, l_values).unwrap();
|
||||||
|
|
||||||
let (nrows, ncols) = (l_factor.nrows(), l_factor.ncols());
|
let (nrows, ncols) = (l_factor.nrows(), l_factor.ncols());
|
||||||
|
|
||||||
|
@ -169,7 +176,7 @@ impl<T: RealField> CscCholesky<T> {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns the Cholesky factor `L`.
|
/// Returns the Cholesky factor `L`.
|
||||||
pub fn take_l(self) -> CscMatrix<T> {
|
pub fn take_l(self) -> CscMatrix<T> {
|
||||||
self.l_factor
|
self.l_factor
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -229,11 +236,9 @@ impl<T: RealField> CscCholesky<T> {
|
||||||
|
|
||||||
{
|
{
|
||||||
let (offsets, _, values) = self.l_factor.csc_data_mut();
|
let (offsets, _, values) = self.l_factor.csc_data_mut();
|
||||||
*values
|
*values.get_unchecked_mut(*offsets.get_unchecked(k)) = denom;
|
||||||
.get_unchecked_mut(*offsets.get_unchecked(k)) = denom;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
let mut col_k = self.l_factor.col_mut(k);
|
let mut col_k = self.l_factor.col_mut(k);
|
||||||
let (col_k_rows, col_k_values) = col_k.rows_and_values_mut();
|
let (col_k_rows, col_k_values) = col_k.rows_and_values_mut();
|
||||||
let col_k_entries = col_k_rows.iter().zip(col_k_values);
|
let col_k_entries = col_k_rows.iter().zip(col_k_values);
|
||||||
|
@ -269,19 +274,16 @@ impl<T: RealField> CscCholesky<T> {
|
||||||
/// # Panics
|
/// # Panics
|
||||||
///
|
///
|
||||||
/// Panics if `b` is not square.
|
/// Panics if `b` is not square.
|
||||||
pub fn solve_mut<'a>(&'a self, b: impl Into<DMatrixSliceMut<'a, T>>)
|
pub fn solve_mut<'a>(&'a self, b: impl Into<DMatrixSliceMut<'a, T>>) {
|
||||||
{
|
|
||||||
let expect_msg = "If the Cholesky factorization succeeded,\
|
let expect_msg = "If the Cholesky factorization succeeded,\
|
||||||
then the triangular solve should never fail";
|
then the triangular solve should never fail";
|
||||||
// Solve LY = B
|
// Solve LY = B
|
||||||
let mut y = b.into();
|
let mut y = b.into();
|
||||||
spsolve_csc_lower_triangular(Op::NoOp(self.l()), &mut y)
|
spsolve_csc_lower_triangular(Op::NoOp(self.l()), &mut y).expect(expect_msg);
|
||||||
.expect(expect_msg);
|
|
||||||
|
|
||||||
// Solve L^T X = Y
|
// Solve L^T X = Y
|
||||||
let mut x = y;
|
let mut x = y;
|
||||||
spsolve_csc_lower_triangular(Op::Transpose(self.l()), &mut x)
|
spsolve_csc_lower_triangular(Op::Transpose(self.l()), &mut x).expect(expect_msg);
|
||||||
.expect(expect_msg);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -333,8 +335,8 @@ fn nonzero_pattern(m: &SparsityPattern) -> (SparsityPattern, SparsityPattern) {
|
||||||
col_offsets.push(rows.len());
|
col_offsets.push(rows.len());
|
||||||
}
|
}
|
||||||
|
|
||||||
let u_pattern = SparsityPattern::try_from_offsets_and_indices(nrows, ncols, col_offsets, rows)
|
let u_pattern =
|
||||||
.unwrap();
|
SparsityPattern::try_from_offsets_and_indices(nrows, ncols, col_offsets, rows).unwrap();
|
||||||
|
|
||||||
// TODO: Avoid this transpose?
|
// TODO: Avoid this transpose?
|
||||||
let l_pattern = u_pattern.transpose();
|
let l_pattern = u_pattern.transpose();
|
||||||
|
@ -368,4 +370,4 @@ fn elimination_tree(pattern: &SparsityPattern) -> Vec<usize> {
|
||||||
}
|
}
|
||||||
|
|
||||||
forest
|
forest
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,4 +3,4 @@
|
||||||
//! Currently, the only factorization provided here is the [`CscCholesky`] factorization.
|
//! Currently, the only factorization provided here is the [`CscCholesky`] factorization.
|
||||||
mod cholesky;
|
mod cholesky;
|
||||||
|
|
||||||
pub use cholesky::*;
|
pub use cholesky::*;
|
||||||
|
|
|
@ -135,13 +135,13 @@
|
||||||
#![deny(unused_results)]
|
#![deny(unused_results)]
|
||||||
#![deny(missing_docs)]
|
#![deny(missing_docs)]
|
||||||
|
|
||||||
|
pub mod convert;
|
||||||
pub mod coo;
|
pub mod coo;
|
||||||
pub mod csc;
|
pub mod csc;
|
||||||
pub mod csr;
|
pub mod csr;
|
||||||
pub mod pattern;
|
|
||||||
pub mod ops;
|
|
||||||
pub mod convert;
|
|
||||||
pub mod factorization;
|
pub mod factorization;
|
||||||
|
pub mod ops;
|
||||||
|
pub mod pattern;
|
||||||
|
|
||||||
pub(crate) mod cs;
|
pub(crate) mod cs;
|
||||||
|
|
||||||
|
@ -151,16 +151,16 @@ pub mod proptest;
|
||||||
#[cfg(feature = "compare")]
|
#[cfg(feature = "compare")]
|
||||||
mod matrixcompare;
|
mod matrixcompare;
|
||||||
|
|
||||||
|
use num_traits::Zero;
|
||||||
use std::error::Error;
|
use std::error::Error;
|
||||||
use std::fmt;
|
use std::fmt;
|
||||||
use num_traits::Zero;
|
|
||||||
|
|
||||||
/// Errors produced by functions that expect well-formed sparse format data.
|
/// Errors produced by functions that expect well-formed sparse format data.
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct SparseFormatError {
|
pub struct SparseFormatError {
|
||||||
kind: SparseFormatErrorKind,
|
kind: SparseFormatErrorKind,
|
||||||
// Currently we only use an underlying error for generating the `Display` impl
|
// Currently we only use an underlying error for generating the `Display` impl
|
||||||
error: Box<dyn Error>
|
error: Box<dyn Error>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl SparseFormatError {
|
impl SparseFormatError {
|
||||||
|
@ -170,10 +170,7 @@ impl SparseFormatError {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn from_kind_and_error(kind: SparseFormatErrorKind, error: Box<dyn Error>) -> Self {
|
pub(crate) fn from_kind_and_error(kind: SparseFormatErrorKind, error: Box<dyn Error>) -> Self {
|
||||||
Self {
|
Self { kind, error }
|
||||||
kind,
|
|
||||||
error
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Helper functionality for more conveniently creating errors.
|
/// Helper functionality for more conveniently creating errors.
|
||||||
|
@ -221,7 +218,7 @@ pub enum SparseEntry<'a, T> {
|
||||||
/// is explicitly stored (a so-called "explicit zero").
|
/// is explicitly stored (a so-called "explicit zero").
|
||||||
NonZero(&'a T),
|
NonZero(&'a T),
|
||||||
/// The entry is implicitly zero, i.e. it is not explicitly stored.
|
/// The entry is implicitly zero, i.e. it is not explicitly stored.
|
||||||
Zero
|
Zero,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, T: Clone + Zero> SparseEntry<'a, T> {
|
impl<'a, T: Clone + Zero> SparseEntry<'a, T> {
|
||||||
|
@ -232,7 +229,7 @@ impl<'a, T: Clone + Zero> SparseEntry<'a, T> {
|
||||||
pub fn to_value(self) -> T {
|
pub fn to_value(self) -> T {
|
||||||
match self {
|
match self {
|
||||||
SparseEntry::NonZero(value) => value.clone(),
|
SparseEntry::NonZero(value) => value.clone(),
|
||||||
SparseEntry::Zero => T::zero()
|
SparseEntry::Zero => T::zero(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -248,7 +245,7 @@ pub enum SparseEntryMut<'a, T> {
|
||||||
/// is explicitly stored (a so-called "explicit zero").
|
/// is explicitly stored (a so-called "explicit zero").
|
||||||
NonZero(&'a mut T),
|
NonZero(&'a mut T),
|
||||||
/// The entry is implicitly zero i.e. it is not explicitly stored.
|
/// The entry is implicitly zero i.e. it is not explicitly stored.
|
||||||
Zero
|
Zero,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, T: Clone + Zero> SparseEntryMut<'a, T> {
|
impl<'a, T: Clone + Zero> SparseEntryMut<'a, T> {
|
||||||
|
@ -259,7 +256,7 @@ impl<'a, T: Clone + Zero> SparseEntryMut<'a, T> {
|
||||||
pub fn to_value(self) -> T {
|
pub fn to_value(self) -> T {
|
||||||
match self {
|
match self {
|
||||||
SparseEntryMut::NonZero(value) => value.clone(),
|
SparseEntryMut::NonZero(value) => value.clone(),
|
||||||
SparseEntryMut::Zero => T::zero()
|
SparseEntryMut::Zero => T::zero(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,36 +1,38 @@
|
||||||
//! Implements core traits for use with `matrixcompare`.
|
//! Implements core traits for use with `matrixcompare`.
|
||||||
use crate::csr::CsrMatrix;
|
use crate::coo::CooMatrix;
|
||||||
use crate::csc::CscMatrix;
|
use crate::csc::CscMatrix;
|
||||||
|
use crate::csr::CsrMatrix;
|
||||||
use matrixcompare_core;
|
use matrixcompare_core;
|
||||||
use matrixcompare_core::{Access, SparseAccess};
|
use matrixcompare_core::{Access, SparseAccess};
|
||||||
use crate::coo::CooMatrix;
|
|
||||||
|
|
||||||
macro_rules! impl_matrix_for_csr_csc {
|
macro_rules! impl_matrix_for_csr_csc {
|
||||||
($MatrixType:ident) => {
|
($MatrixType:ident) => {
|
||||||
impl<T: Clone> SparseAccess<T> for $MatrixType<T> {
|
impl<T: Clone> SparseAccess<T> for $MatrixType<T> {
|
||||||
fn nnz(&self) -> usize {
|
fn nnz(&self) -> usize {
|
||||||
$MatrixType::nnz(self)
|
$MatrixType::nnz(self)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn fetch_triplets(&self) -> Vec<(usize, usize, T)> {
|
||||||
|
self.triplet_iter()
|
||||||
|
.map(|(i, j, v)| (i, j, v.clone()))
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn fetch_triplets(&self) -> Vec<(usize, usize, T)> {
|
impl<T: Clone> matrixcompare_core::Matrix<T> for $MatrixType<T> {
|
||||||
self.triplet_iter().map(|(i, j, v)| (i, j, v.clone())).collect()
|
fn rows(&self) -> usize {
|
||||||
}
|
self.nrows()
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: Clone> matrixcompare_core::Matrix<T> for $MatrixType<T> {
|
fn cols(&self) -> usize {
|
||||||
fn rows(&self) -> usize {
|
self.ncols()
|
||||||
self.nrows()
|
}
|
||||||
}
|
|
||||||
|
|
||||||
fn cols(&self) -> usize {
|
fn access(&self) -> Access<T> {
|
||||||
self.ncols()
|
Access::Sparse(self)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
};
|
||||||
fn access(&self) -> Access<T> {
|
|
||||||
Access::Sparse(self)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl_matrix_for_csr_csc!(CsrMatrix);
|
impl_matrix_for_csr_csc!(CsrMatrix);
|
||||||
|
@ -42,7 +44,9 @@ impl<T: Clone> SparseAccess<T> for CooMatrix<T> {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn fetch_triplets(&self) -> Vec<(usize, usize, T)> {
|
fn fetch_triplets(&self) -> Vec<(usize, usize, T)> {
|
||||||
self.triplet_iter().map(|(i, j, v)| (i, j, v.clone())).collect()
|
self.triplet_iter()
|
||||||
|
.map(|(i, j, v)| (i, j, v.clone()))
|
||||||
|
.collect()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -58,4 +62,4 @@ impl<T: Clone> matrixcompare_core::Matrix<T> for CooMatrix<T> {
|
||||||
fn access(&self) -> Access<T> {
|
fn access(&self) -> Access<T> {
|
||||||
Access::Sparse(self)
|
Access::Sparse(self)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,15 +1,20 @@
|
||||||
use crate::csr::CsrMatrix;
|
|
||||||
use crate::csc::CscMatrix;
|
use crate::csc::CscMatrix;
|
||||||
|
use crate::csr::CsrMatrix;
|
||||||
|
|
||||||
use std::ops::{Add, Div, DivAssign, Mul, MulAssign, Sub, Neg};
|
use crate::ops::serial::{
|
||||||
use crate::ops::serial::{spadd_csr_prealloc, spadd_csc_prealloc, spadd_pattern, spmm_csr_pattern, spmm_csr_prealloc, spmm_csc_prealloc, spmm_csc_dense, spmm_csr_dense, spmm_csc_pattern};
|
spadd_csc_prealloc, spadd_csr_prealloc, spadd_pattern, spmm_csc_dense, spmm_csc_pattern,
|
||||||
use nalgebra::{ClosedAdd, ClosedMul, ClosedSub, ClosedDiv, Scalar, Matrix, MatrixMN, Dim,
|
spmm_csc_prealloc, spmm_csr_dense, spmm_csr_pattern, spmm_csr_prealloc,
|
||||||
Dynamic, DefaultAllocator, U1};
|
};
|
||||||
use nalgebra::allocator::{Allocator};
|
use crate::ops::Op;
|
||||||
use nalgebra::constraint::{DimEq, ShapeConstraint};
|
use nalgebra::allocator::Allocator;
|
||||||
use num_traits::{Zero, One};
|
|
||||||
use crate::ops::{Op};
|
|
||||||
use nalgebra::base::storage::Storage;
|
use nalgebra::base::storage::Storage;
|
||||||
|
use nalgebra::constraint::{DimEq, ShapeConstraint};
|
||||||
|
use nalgebra::{
|
||||||
|
ClosedAdd, ClosedDiv, ClosedMul, ClosedSub, DefaultAllocator, Dim, Dynamic, Matrix, MatrixMN,
|
||||||
|
Scalar, U1,
|
||||||
|
};
|
||||||
|
use num_traits::{One, Zero};
|
||||||
|
use std::ops::{Add, Div, DivAssign, Mul, MulAssign, Neg, Sub};
|
||||||
|
|
||||||
/// Helper macro for implementing binary operators for different matrix types
|
/// Helper macro for implementing binary operators for different matrix types
|
||||||
/// See below for usage.
|
/// See below for usage.
|
||||||
|
@ -188,7 +193,7 @@ macro_rules! impl_neg {
|
||||||
($matrix_type:ident) => {
|
($matrix_type:ident) => {
|
||||||
impl<T> Neg for $matrix_type<T>
|
impl<T> Neg for $matrix_type<T>
|
||||||
where
|
where
|
||||||
T: Scalar + Neg<Output=T>
|
T: Scalar + Neg<Output = T>,
|
||||||
{
|
{
|
||||||
type Output = $matrix_type<T>;
|
type Output = $matrix_type<T>;
|
||||||
|
|
||||||
|
@ -202,7 +207,7 @@ macro_rules! impl_neg {
|
||||||
|
|
||||||
impl<'a, T> Neg for &'a $matrix_type<T>
|
impl<'a, T> Neg for &'a $matrix_type<T>
|
||||||
where
|
where
|
||||||
T: Scalar + Neg<Output=T>
|
T: Scalar + Neg<Output = T>,
|
||||||
{
|
{
|
||||||
type Output = $matrix_type<T>;
|
type Output = $matrix_type<T>;
|
||||||
|
|
||||||
|
@ -211,10 +216,10 @@ macro_rules! impl_neg {
|
||||||
// obtain both the sparsity pattern and values from the matrix,
|
// obtain both the sparsity pattern and values from the matrix,
|
||||||
// and then modify the values before creating a new matrix from the pattern
|
// and then modify the values before creating a new matrix from the pattern
|
||||||
// and negated values.
|
// and negated values.
|
||||||
- self.clone()
|
-self.clone()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
impl_neg!(CsrMatrix);
|
impl_neg!(CsrMatrix);
|
||||||
|
@ -323,4 +328,4 @@ macro_rules! impl_spmm_cs_dense {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl_spmm_cs_dense!(CsrMatrix, spmm_csr_dense);
|
impl_spmm_cs_dense!(CsrMatrix, spmm_csr_dense);
|
||||||
impl_spmm_cs_dense!(CscMatrix, spmm_csc_dense);
|
impl_spmm_cs_dense!(CscMatrix, spmm_csc_dense);
|
||||||
|
|
|
@ -148,13 +148,14 @@ impl<T> Op<T> {
|
||||||
pub fn as_ref(&self) -> Op<&T> {
|
pub fn as_ref(&self) -> Op<&T> {
|
||||||
match self {
|
match self {
|
||||||
Op::NoOp(obj) => Op::NoOp(&obj),
|
Op::NoOp(obj) => Op::NoOp(&obj),
|
||||||
Op::Transpose(obj) => Op::Transpose(&obj)
|
Op::Transpose(obj) => Op::Transpose(&obj),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Converts the underlying data type.
|
/// Converts the underlying data type.
|
||||||
pub fn convert<U>(self) -> Op<U>
|
pub fn convert<U>(self) -> Op<U>
|
||||||
where T: Into<U>
|
where
|
||||||
|
T: Into<U>,
|
||||||
{
|
{
|
||||||
self.map_same_op(T::into)
|
self.map_same_op(T::into)
|
||||||
}
|
}
|
||||||
|
@ -163,7 +164,7 @@ impl<T> Op<T> {
|
||||||
pub fn map_same_op<U, F: FnOnce(T) -> U>(self, f: F) -> Op<U> {
|
pub fn map_same_op<U, F: FnOnce(T) -> U>(self, f: F) -> Op<U> {
|
||||||
match self {
|
match self {
|
||||||
Op::NoOp(obj) => Op::NoOp(f(obj)),
|
Op::NoOp(obj) => Op::NoOp(f(obj)),
|
||||||
Op::Transpose(obj) => Op::Transpose(f(obj))
|
Op::Transpose(obj) => Op::Transpose(f(obj)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -181,7 +182,7 @@ impl<T> Op<T> {
|
||||||
pub fn transposed(self) -> Self {
|
pub fn transposed(self) -> Self {
|
||||||
match self {
|
match self {
|
||||||
Op::NoOp(obj) => Op::Transpose(obj),
|
Op::NoOp(obj) => Op::Transpose(obj),
|
||||||
Op::Transpose(obj) => Op::NoOp(obj)
|
Op::Transpose(obj) => Op::NoOp(obj),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -191,4 +192,3 @@ impl<T> From<T> for Op<T> {
|
||||||
Self::NoOp(obj)
|
Self::NoOp(obj)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,14 +1,15 @@
|
||||||
use crate::cs::CsMatrix;
|
use crate::cs::CsMatrix;
|
||||||
|
use crate::ops::serial::{OperationError, OperationErrorKind};
|
||||||
use crate::ops::Op;
|
use crate::ops::Op;
|
||||||
use crate::ops::serial::{OperationErrorKind, OperationError};
|
|
||||||
use nalgebra::{Scalar, ClosedAdd, ClosedMul, DMatrixSliceMut, DMatrixSlice};
|
|
||||||
use num_traits::{Zero, One};
|
|
||||||
use crate::SparseEntryMut;
|
use crate::SparseEntryMut;
|
||||||
|
use nalgebra::{ClosedAdd, ClosedMul, DMatrixSlice, DMatrixSliceMut, Scalar};
|
||||||
|
use num_traits::{One, Zero};
|
||||||
|
|
||||||
fn spmm_cs_unexpected_entry() -> OperationError {
|
fn spmm_cs_unexpected_entry() -> OperationError {
|
||||||
OperationError::from_kind_and_message(
|
OperationError::from_kind_and_message(
|
||||||
OperationErrorKind::InvalidPattern,
|
OperationErrorKind::InvalidPattern,
|
||||||
String::from("Found unexpected entry that is not present in `c`."))
|
String::from("Found unexpected entry that is not present in `c`."),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Helper functionality for implementing CSR/CSC SPMM.
|
/// Helper functionality for implementing CSR/CSC SPMM.
|
||||||
|
@ -24,12 +25,12 @@ pub fn spmm_cs_prealloc<T>(
|
||||||
c: &mut CsMatrix<T>,
|
c: &mut CsMatrix<T>,
|
||||||
alpha: T,
|
alpha: T,
|
||||||
a: &CsMatrix<T>,
|
a: &CsMatrix<T>,
|
||||||
b: &CsMatrix<T>)
|
b: &CsMatrix<T>,
|
||||||
-> Result<(), OperationError>
|
) -> Result<(), OperationError>
|
||||||
where
|
where
|
||||||
T: Scalar + ClosedAdd + ClosedMul + Zero + One
|
T: Scalar + ClosedAdd + ClosedMul + Zero + One,
|
||||||
{
|
{
|
||||||
for i in 0 .. c.pattern().major_dim() {
|
for i in 0..c.pattern().major_dim() {
|
||||||
let a_lane_i = a.get_lane(i).unwrap();
|
let a_lane_i = a.get_lane(i).unwrap();
|
||||||
let mut c_lane_i = c.get_lane_mut(i).unwrap();
|
let mut c_lane_i = c.get_lane_mut(i).unwrap();
|
||||||
for c_ij in c_lane_i.values_mut() {
|
for c_ij in c_lane_i.values_mut() {
|
||||||
|
@ -42,14 +43,15 @@ pub fn spmm_cs_prealloc<T>(
|
||||||
let alpha_aik = alpha.inlined_clone() * a_ik.inlined_clone();
|
let alpha_aik = alpha.inlined_clone() * a_ik.inlined_clone();
|
||||||
for (j, b_kj) in b_lane_k.minor_indices().iter().zip(b_lane_k.values()) {
|
for (j, b_kj) in b_lane_k.minor_indices().iter().zip(b_lane_k.values()) {
|
||||||
// Determine the location in C to append the value
|
// Determine the location in C to append the value
|
||||||
let (c_local_idx, _) = c_lane_i_cols.iter()
|
let (c_local_idx, _) = c_lane_i_cols
|
||||||
|
.iter()
|
||||||
.enumerate()
|
.enumerate()
|
||||||
.find(|(_, c_col)| *c_col == j)
|
.find(|(_, c_col)| *c_col == j)
|
||||||
.ok_or_else(spmm_cs_unexpected_entry)?;
|
.ok_or_else(spmm_cs_unexpected_entry)?;
|
||||||
|
|
||||||
c_lane_i_values[c_local_idx] += alpha_aik.inlined_clone() * b_kj.inlined_clone();
|
c_lane_i_values[c_local_idx] += alpha_aik.inlined_clone() * b_kj.inlined_clone();
|
||||||
c_lane_i_cols = &c_lane_i_cols[c_local_idx ..];
|
c_lane_i_cols = &c_lane_i_cols[c_local_idx..];
|
||||||
c_lane_i_values = &mut c_lane_i_values[c_local_idx ..];
|
c_lane_i_values = &mut c_lane_i_values[c_local_idx..];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -60,17 +62,19 @@ pub fn spmm_cs_prealloc<T>(
|
||||||
fn spadd_cs_unexpected_entry() -> OperationError {
|
fn spadd_cs_unexpected_entry() -> OperationError {
|
||||||
OperationError::from_kind_and_message(
|
OperationError::from_kind_and_message(
|
||||||
OperationErrorKind::InvalidPattern,
|
OperationErrorKind::InvalidPattern,
|
||||||
String::from("Found entry in `op(a)` that is not present in `c`."))
|
String::from("Found entry in `op(a)` that is not present in `c`."),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Helper functionality for implementing CSR/CSC SPADD.
|
/// Helper functionality for implementing CSR/CSC SPADD.
|
||||||
pub fn spadd_cs_prealloc<T>(beta: T,
|
pub fn spadd_cs_prealloc<T>(
|
||||||
c: &mut CsMatrix<T>,
|
beta: T,
|
||||||
alpha: T,
|
c: &mut CsMatrix<T>,
|
||||||
a: Op<&CsMatrix<T>>)
|
alpha: T,
|
||||||
-> Result<(), OperationError>
|
a: Op<&CsMatrix<T>>,
|
||||||
where
|
) -> Result<(), OperationError>
|
||||||
T: Scalar + ClosedAdd + ClosedMul + Zero + One
|
where
|
||||||
|
T: Scalar + ClosedAdd + ClosedMul + Zero + One,
|
||||||
{
|
{
|
||||||
match a {
|
match a {
|
||||||
Op::NoOp(a) => {
|
Op::NoOp(a) => {
|
||||||
|
@ -88,13 +92,14 @@ pub fn spadd_cs_prealloc<T>(beta: T,
|
||||||
// TODO: Use exponential search instead of linear search.
|
// TODO: Use exponential search instead of linear search.
|
||||||
// If C has substantially more entries in the row than A, then a line search
|
// If C has substantially more entries in the row than A, then a line search
|
||||||
// will needlessly visit many entries in C.
|
// will needlessly visit many entries in C.
|
||||||
let (c_idx, _) = c_minors.iter()
|
let (c_idx, _) = c_minors
|
||||||
|
.iter()
|
||||||
.enumerate()
|
.enumerate()
|
||||||
.find(|(_, c_col)| *c_col == a_col)
|
.find(|(_, c_col)| *c_col == a_col)
|
||||||
.ok_or_else(spadd_cs_unexpected_entry)?;
|
.ok_or_else(spadd_cs_unexpected_entry)?;
|
||||||
c_vals[c_idx] += alpha.inlined_clone() * a_val.inlined_clone();
|
c_vals[c_idx] += alpha.inlined_clone() * a_val.inlined_clone();
|
||||||
c_minors = &c_minors[c_idx ..];
|
c_minors = &c_minors[c_idx..];
|
||||||
c_vals = &mut c_vals[c_idx ..];
|
c_vals = &mut c_vals[c_idx..];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -110,7 +115,7 @@ pub fn spadd_cs_prealloc<T>(beta: T,
|
||||||
let a_val = a_val.inlined_clone();
|
let a_val = a_val.inlined_clone();
|
||||||
let alpha = alpha.inlined_clone();
|
let alpha = alpha.inlined_clone();
|
||||||
match c.get_entry_mut(j, i).unwrap() {
|
match c.get_entry_mut(j, i).unwrap() {
|
||||||
SparseEntryMut::NonZero(c_ji) => { *c_ji += alpha * a_val }
|
SparseEntryMut::NonZero(c_ji) => *c_ji += alpha * a_val,
|
||||||
SparseEntryMut::Zero => return Err(spadd_cs_unexpected_entry()),
|
SparseEntryMut::Zero => return Err(spadd_cs_unexpected_entry()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -124,13 +129,14 @@ pub fn spadd_cs_prealloc<T>(beta: T,
|
||||||
///
|
///
|
||||||
/// The implementation essentially assumes that `a` is a CSR matrix. To use it with CSC matrices,
|
/// The implementation essentially assumes that `a` is a CSR matrix. To use it with CSC matrices,
|
||||||
/// the transposed operation must be specified for the CSC matrix.
|
/// the transposed operation must be specified for the CSC matrix.
|
||||||
pub fn spmm_cs_dense<T>(beta: T,
|
pub fn spmm_cs_dense<T>(
|
||||||
mut c: DMatrixSliceMut<T>,
|
beta: T,
|
||||||
alpha: T,
|
mut c: DMatrixSliceMut<T>,
|
||||||
a: Op<&CsMatrix<T>>,
|
alpha: T,
|
||||||
b: Op<DMatrixSlice<T>>)
|
a: Op<&CsMatrix<T>>,
|
||||||
where
|
b: Op<DMatrixSlice<T>>,
|
||||||
T: Scalar + ClosedAdd + ClosedMul + Zero + One
|
) where
|
||||||
|
T: Scalar + ClosedAdd + ClosedMul + Zero + One,
|
||||||
{
|
{
|
||||||
match a {
|
match a {
|
||||||
Op::NoOp(a) => {
|
Op::NoOp(a) => {
|
||||||
|
@ -139,17 +145,17 @@ pub fn spmm_cs_dense<T>(beta: T,
|
||||||
for (c_ij, a_row_i) in c_col_j.iter_mut().zip(a.lane_iter()) {
|
for (c_ij, a_row_i) in c_col_j.iter_mut().zip(a.lane_iter()) {
|
||||||
let mut dot_ij = T::zero();
|
let mut dot_ij = T::zero();
|
||||||
for (&k, a_ik) in a_row_i.minor_indices().iter().zip(a_row_i.values()) {
|
for (&k, a_ik) in a_row_i.minor_indices().iter().zip(a_row_i.values()) {
|
||||||
let b_contrib =
|
let b_contrib = match b {
|
||||||
match b {
|
Op::NoOp(ref b) => b.index((k, j)),
|
||||||
Op::NoOp(ref b) => b.index((k, j)),
|
Op::Transpose(ref b) => b.index((j, k)),
|
||||||
Op::Transpose(ref b) => b.index((j, k))
|
};
|
||||||
};
|
|
||||||
dot_ij += a_ik.inlined_clone() * b_contrib.inlined_clone();
|
dot_ij += a_ik.inlined_clone() * b_contrib.inlined_clone();
|
||||||
}
|
}
|
||||||
*c_ij = beta.inlined_clone() * c_ij.inlined_clone() + alpha.inlined_clone() * dot_ij;
|
*c_ij = beta.inlined_clone() * c_ij.inlined_clone()
|
||||||
|
+ alpha.inlined_clone() * dot_ij;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
}
|
||||||
Op::Transpose(a) => {
|
Op::Transpose(a) => {
|
||||||
// In this case, we have to pre-multiply C by beta
|
// In this case, we have to pre-multiply C by beta
|
||||||
c *= beta;
|
c *= beta;
|
||||||
|
@ -165,17 +171,16 @@ pub fn spmm_cs_dense<T>(beta: T,
|
||||||
for (c_ij, b_kj) in c_row_i.iter_mut().zip(b_row_k.iter()) {
|
for (c_ij, b_kj) in c_row_i.iter_mut().zip(b_row_k.iter()) {
|
||||||
*c_ij += gamma_ki.inlined_clone() * b_kj.inlined_clone();
|
*c_ij += gamma_ki.inlined_clone() * b_kj.inlined_clone();
|
||||||
}
|
}
|
||||||
},
|
}
|
||||||
Op::Transpose(ref b) => {
|
Op::Transpose(ref b) => {
|
||||||
let b_col_k = b.column(k);
|
let b_col_k = b.column(k);
|
||||||
for (c_ij, b_jk) in c_row_i.iter_mut().zip(b_col_k.iter()) {
|
for (c_ij, b_jk) in c_row_i.iter_mut().zip(b_col_k.iter()) {
|
||||||
*c_ij += gamma_ki.inlined_clone() * b_jk.inlined_clone();
|
*c_ij += gamma_ki.inlined_clone() * b_jk.inlined_clone();
|
||||||
}
|
}
|
||||||
},
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,9 +1,9 @@
|
||||||
use crate::csc::CscMatrix;
|
use crate::csc::CscMatrix;
|
||||||
use crate::ops::Op;
|
use crate::ops::serial::cs::{spadd_cs_prealloc, spmm_cs_dense, spmm_cs_prealloc};
|
||||||
use crate::ops::serial::cs::{spmm_cs_prealloc, spmm_cs_dense, spadd_cs_prealloc};
|
|
||||||
use crate::ops::serial::{OperationError, OperationErrorKind};
|
use crate::ops::serial::{OperationError, OperationErrorKind};
|
||||||
use nalgebra::{Scalar, ClosedAdd, ClosedMul, DMatrixSliceMut, DMatrixSlice, RealField};
|
use crate::ops::Op;
|
||||||
use num_traits::{Zero, One};
|
use nalgebra::{ClosedAdd, ClosedMul, DMatrixSlice, DMatrixSliceMut, RealField, Scalar};
|
||||||
|
use num_traits::{One, Zero};
|
||||||
|
|
||||||
use std::borrow::Cow;
|
use std::borrow::Cow;
|
||||||
|
|
||||||
|
@ -12,25 +12,27 @@ use std::borrow::Cow;
|
||||||
/// # Panics
|
/// # Panics
|
||||||
///
|
///
|
||||||
/// Panics if the dimensions of the matrices involved are not compatible with the expression.
|
/// Panics if the dimensions of the matrices involved are not compatible with the expression.
|
||||||
pub fn spmm_csc_dense<'a, T>(beta: T,
|
pub fn spmm_csc_dense<'a, T>(
|
||||||
c: impl Into<DMatrixSliceMut<'a, T>>,
|
beta: T,
|
||||||
alpha: T,
|
c: impl Into<DMatrixSliceMut<'a, T>>,
|
||||||
a: Op<&CscMatrix<T>>,
|
alpha: T,
|
||||||
b: Op<impl Into<DMatrixSlice<'a, T>>>)
|
a: Op<&CscMatrix<T>>,
|
||||||
where
|
b: Op<impl Into<DMatrixSlice<'a, T>>>,
|
||||||
T: Scalar + ClosedAdd + ClosedMul + Zero + One
|
) where
|
||||||
|
T: Scalar + ClosedAdd + ClosedMul + Zero + One,
|
||||||
{
|
{
|
||||||
let b = b.convert();
|
let b = b.convert();
|
||||||
spmm_csc_dense_(beta, c.into(), alpha, a, b)
|
spmm_csc_dense_(beta, c.into(), alpha, a, b)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn spmm_csc_dense_<T>(beta: T,
|
fn spmm_csc_dense_<T>(
|
||||||
c: DMatrixSliceMut<T>,
|
beta: T,
|
||||||
alpha: T,
|
c: DMatrixSliceMut<T>,
|
||||||
a: Op<&CscMatrix<T>>,
|
alpha: T,
|
||||||
b: Op<DMatrixSlice<T>>)
|
a: Op<&CscMatrix<T>>,
|
||||||
where
|
b: Op<DMatrixSlice<T>>,
|
||||||
T: Scalar + ClosedAdd + ClosedMul + Zero + One
|
) where
|
||||||
|
T: Scalar + ClosedAdd + ClosedMul + Zero + One,
|
||||||
{
|
{
|
||||||
assert_compatible_spmm_dims!(c, a, b);
|
assert_compatible_spmm_dims!(c, a, b);
|
||||||
// Need to interpret matrix as transposed since the spmm_cs_dense function assumes CSR layout
|
// Need to interpret matrix as transposed since the spmm_cs_dense function assumes CSR layout
|
||||||
|
@ -46,19 +48,19 @@ fn spmm_csc_dense_<T>(beta: T,
|
||||||
/// # Panics
|
/// # Panics
|
||||||
///
|
///
|
||||||
/// Panics if the dimensions of the matrices involved are not compatible with the expression.
|
/// Panics if the dimensions of the matrices involved are not compatible with the expression.
|
||||||
pub fn spadd_csc_prealloc<T>(beta: T,
|
pub fn spadd_csc_prealloc<T>(
|
||||||
c: &mut CscMatrix<T>,
|
beta: T,
|
||||||
alpha: T,
|
c: &mut CscMatrix<T>,
|
||||||
a: Op<&CscMatrix<T>>)
|
alpha: T,
|
||||||
-> Result<(), OperationError>
|
a: Op<&CscMatrix<T>>,
|
||||||
where
|
) -> Result<(), OperationError>
|
||||||
T: Scalar + ClosedAdd + ClosedMul + Zero + One
|
where
|
||||||
|
T: Scalar + ClosedAdd + ClosedMul + Zero + One,
|
||||||
{
|
{
|
||||||
assert_compatible_spadd_dims!(c, a);
|
assert_compatible_spadd_dims!(c, a);
|
||||||
spadd_cs_prealloc(beta, &mut c.cs, alpha, a.map_same_op(|a| &a.cs))
|
spadd_cs_prealloc(beta, &mut c.cs, alpha, a.map_same_op(|a| &a.cs))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/// Sparse-sparse matrix multiplication, `C <- beta * C + alpha * op(A) * op(B)`.
|
/// Sparse-sparse matrix multiplication, `C <- beta * C + alpha * op(A) * op(B)`.
|
||||||
///
|
///
|
||||||
/// # Errors
|
/// # Errors
|
||||||
|
@ -74,10 +76,10 @@ pub fn spmm_csc_prealloc<T>(
|
||||||
c: &mut CscMatrix<T>,
|
c: &mut CscMatrix<T>,
|
||||||
alpha: T,
|
alpha: T,
|
||||||
a: Op<&CscMatrix<T>>,
|
a: Op<&CscMatrix<T>>,
|
||||||
b: Op<&CscMatrix<T>>)
|
b: Op<&CscMatrix<T>>,
|
||||||
-> Result<(), OperationError>
|
) -> Result<(), OperationError>
|
||||||
where
|
where
|
||||||
T: Scalar + ClosedAdd + ClosedMul + Zero + One
|
T: Scalar + ClosedAdd + ClosedMul + Zero + One,
|
||||||
{
|
{
|
||||||
assert_compatible_spmm_dims!(c, a, b);
|
assert_compatible_spmm_dims!(c, a, b);
|
||||||
|
|
||||||
|
@ -87,7 +89,7 @@ pub fn spmm_csc_prealloc<T>(
|
||||||
(NoOp(ref a), NoOp(ref b)) => {
|
(NoOp(ref a), NoOp(ref b)) => {
|
||||||
// Note: We have to reverse the order for CSC matrices
|
// Note: We have to reverse the order for CSC matrices
|
||||||
spmm_cs_prealloc(beta, &mut c.cs, alpha, &b.cs, &a.cs)
|
spmm_cs_prealloc(beta, &mut c.cs, alpha, &b.cs, &a.cs)
|
||||||
},
|
}
|
||||||
_ => {
|
_ => {
|
||||||
// Currently we handle transposition by explicitly precomputing transposed matrices
|
// Currently we handle transposition by explicitly precomputing transposed matrices
|
||||||
// and calling the operation again without transposition
|
// and calling the operation again without transposition
|
||||||
|
@ -99,7 +101,9 @@ pub fn spmm_csc_prealloc<T>(
|
||||||
(NoOp(_), NoOp(_)) => unreachable!(),
|
(NoOp(_), NoOp(_)) => unreachable!(),
|
||||||
(Transpose(ref a), NoOp(_)) => (Owned(a.transpose()), Borrowed(b_ref)),
|
(Transpose(ref a), NoOp(_)) => (Owned(a.transpose()), Borrowed(b_ref)),
|
||||||
(NoOp(_), Transpose(ref b)) => (Borrowed(a_ref), Owned(b.transpose())),
|
(NoOp(_), Transpose(ref b)) => (Borrowed(a_ref), Owned(b.transpose())),
|
||||||
(Transpose(ref a), Transpose(ref b)) => (Owned(a.transpose()), Owned(b.transpose()))
|
(Transpose(ref a), Transpose(ref b)) => {
|
||||||
|
(Owned(a.transpose()), Owned(b.transpose()))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -121,13 +125,20 @@ pub fn spmm_csc_prealloc<T>(
|
||||||
/// Panics if `L` is not square, or if `L` and `B` are not dimensionally compatible.
|
/// Panics if `L` is not square, or if `L` and `B` are not dimensionally compatible.
|
||||||
pub fn spsolve_csc_lower_triangular<'a, T: RealField>(
|
pub fn spsolve_csc_lower_triangular<'a, T: RealField>(
|
||||||
l: Op<&CscMatrix<T>>,
|
l: Op<&CscMatrix<T>>,
|
||||||
b: impl Into<DMatrixSliceMut<'a, T>>)
|
b: impl Into<DMatrixSliceMut<'a, T>>,
|
||||||
-> Result<(), OperationError>
|
) -> Result<(), OperationError> {
|
||||||
{
|
|
||||||
let b = b.into();
|
let b = b.into();
|
||||||
let l_matrix = l.into_inner();
|
let l_matrix = l.into_inner();
|
||||||
assert_eq!(l_matrix.nrows(), l_matrix.ncols(), "Matrix must be square for triangular solve.");
|
assert_eq!(
|
||||||
assert_eq!(l_matrix.nrows(), b.nrows(), "Dimension mismatch in sparse lower triangular solver.");
|
l_matrix.nrows(),
|
||||||
|
l_matrix.ncols(),
|
||||||
|
"Matrix must be square for triangular solve."
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
l_matrix.nrows(),
|
||||||
|
b.nrows(),
|
||||||
|
"Dimension mismatch in sparse lower triangular solver."
|
||||||
|
);
|
||||||
match l {
|
match l {
|
||||||
Op::NoOp(a) => spsolve_csc_lower_triangular_no_transpose(a, b),
|
Op::NoOp(a) => spsolve_csc_lower_triangular_no_transpose(a, b),
|
||||||
Op::Transpose(a) => spsolve_csc_lower_triangular_transpose(a, b),
|
Op::Transpose(a) => spsolve_csc_lower_triangular_transpose(a, b),
|
||||||
|
@ -136,16 +147,15 @@ pub fn spsolve_csc_lower_triangular<'a, T: RealField>(
|
||||||
|
|
||||||
fn spsolve_csc_lower_triangular_no_transpose<T: RealField>(
|
fn spsolve_csc_lower_triangular_no_transpose<T: RealField>(
|
||||||
l: &CscMatrix<T>,
|
l: &CscMatrix<T>,
|
||||||
b: DMatrixSliceMut<T>)
|
b: DMatrixSliceMut<T>,
|
||||||
-> Result<(), OperationError>
|
) -> Result<(), OperationError> {
|
||||||
{
|
|
||||||
let mut x = b;
|
let mut x = b;
|
||||||
|
|
||||||
// Solve column-by-column
|
// Solve column-by-column
|
||||||
for j in 0 .. x.ncols() {
|
for j in 0..x.ncols() {
|
||||||
let mut x_col_j = x.column_mut(j);
|
let mut x_col_j = x.column_mut(j);
|
||||||
|
|
||||||
for k in 0 .. l.ncols() {
|
for k in 0..l.ncols() {
|
||||||
let l_col_k = l.col(k);
|
let l_col_k = l.col(k);
|
||||||
|
|
||||||
// Skip entries above the diagonal
|
// Skip entries above the diagonal
|
||||||
|
@ -163,8 +173,8 @@ fn spsolve_csc_lower_triangular_no_transpose<T: RealField>(
|
||||||
// Copy value after updating (so we don't run into the borrow checker)
|
// Copy value after updating (so we don't run into the borrow checker)
|
||||||
let x_kj = x_col_j[k];
|
let x_kj = x_col_j[k];
|
||||||
|
|
||||||
let row_indices = &l_col_k.row_indices()[(diag_csc_index + 1) ..];
|
let row_indices = &l_col_k.row_indices()[(diag_csc_index + 1)..];
|
||||||
let l_values = &l_col_k.values()[(diag_csc_index + 1) ..];
|
let l_values = &l_col_k.values()[(diag_csc_index + 1)..];
|
||||||
|
|
||||||
// Note: The remaining entries are below the diagonal
|
// Note: The remaining entries are below the diagonal
|
||||||
for (&i, l_ik) in row_indices.iter().zip(l_values) {
|
for (&i, l_ik) in row_indices.iter().zip(l_values) {
|
||||||
|
@ -187,24 +197,26 @@ fn spsolve_csc_lower_triangular_no_transpose<T: RealField>(
|
||||||
|
|
||||||
fn spsolve_encountered_zero_diagonal() -> Result<(), OperationError> {
|
fn spsolve_encountered_zero_diagonal() -> Result<(), OperationError> {
|
||||||
let message = "Matrix contains at least one diagonal entry that is zero.";
|
let message = "Matrix contains at least one diagonal entry that is zero.";
|
||||||
Err(OperationError::from_kind_and_message(OperationErrorKind::Singular, String::from(message)))
|
Err(OperationError::from_kind_and_message(
|
||||||
|
OperationErrorKind::Singular,
|
||||||
|
String::from(message),
|
||||||
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn spsolve_csc_lower_triangular_transpose<T: RealField>(
|
fn spsolve_csc_lower_triangular_transpose<T: RealField>(
|
||||||
l: &CscMatrix<T>,
|
l: &CscMatrix<T>,
|
||||||
b: DMatrixSliceMut<T>)
|
b: DMatrixSliceMut<T>,
|
||||||
-> Result<(), OperationError>
|
) -> Result<(), OperationError> {
|
||||||
{
|
|
||||||
let mut x = b;
|
let mut x = b;
|
||||||
|
|
||||||
// Solve column-by-column
|
// Solve column-by-column
|
||||||
for j in 0 .. x.ncols() {
|
for j in 0..x.ncols() {
|
||||||
let mut x_col_j = x.column_mut(j);
|
let mut x_col_j = x.column_mut(j);
|
||||||
|
|
||||||
// Due to the transposition, we're essentially solving an upper triangular system,
|
// Due to the transposition, we're essentially solving an upper triangular system,
|
||||||
// and the columns in our matrix become rows
|
// and the columns in our matrix become rows
|
||||||
|
|
||||||
for i in (0 .. l.ncols()).rev() {
|
for i in (0..l.ncols()).rev() {
|
||||||
let l_col_i = l.col(i);
|
let l_col_i = l.col(i);
|
||||||
|
|
||||||
// Skip entries above the diagonal
|
// Skip entries above the diagonal
|
||||||
|
@ -220,8 +232,8 @@ fn spsolve_csc_lower_triangular_transpose<T: RealField>(
|
||||||
// Copy value after updating (so we don't run into the borrow checker)
|
// Copy value after updating (so we don't run into the borrow checker)
|
||||||
let mut x_ii = x_col_j[i];
|
let mut x_ii = x_col_j[i];
|
||||||
|
|
||||||
let row_indices = &l_col_i.row_indices()[(diag_csc_index + 1) ..];
|
let row_indices = &l_col_i.row_indices()[(diag_csc_index + 1)..];
|
||||||
let a_values = &l_col_i.values()[(diag_csc_index + 1) ..];
|
let a_values = &l_col_i.values()[(diag_csc_index + 1)..];
|
||||||
|
|
||||||
// Note: The remaining entries are below the diagonal
|
// Note: The remaining entries are below the diagonal
|
||||||
for (&k, &l_ki) in row_indices.iter().zip(a_values) {
|
for (&k, &l_ki) in row_indices.iter().zip(a_values) {
|
||||||
|
@ -240,4 +252,4 @@ fn spsolve_csc_lower_triangular_transpose<T: RealField>(
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,31 +1,33 @@
|
||||||
use crate::csr::CsrMatrix;
|
use crate::csr::CsrMatrix;
|
||||||
use crate::ops::{Op};
|
use crate::ops::serial::cs::{spadd_cs_prealloc, spmm_cs_dense, spmm_cs_prealloc};
|
||||||
use crate::ops::serial::{OperationError};
|
use crate::ops::serial::OperationError;
|
||||||
use nalgebra::{Scalar, DMatrixSlice, ClosedAdd, ClosedMul, DMatrixSliceMut};
|
use crate::ops::Op;
|
||||||
use num_traits::{Zero, One};
|
use nalgebra::{ClosedAdd, ClosedMul, DMatrixSlice, DMatrixSliceMut, Scalar};
|
||||||
|
use num_traits::{One, Zero};
|
||||||
use std::borrow::Cow;
|
use std::borrow::Cow;
|
||||||
use crate::ops::serial::cs::{spmm_cs_prealloc, spmm_cs_dense, spadd_cs_prealloc};
|
|
||||||
|
|
||||||
/// Sparse-dense matrix-matrix multiplication `C <- beta * C + alpha * op(A) * op(B)`.
|
/// Sparse-dense matrix-matrix multiplication `C <- beta * C + alpha * op(A) * op(B)`.
|
||||||
pub fn spmm_csr_dense<'a, T>(beta: T,
|
pub fn spmm_csr_dense<'a, T>(
|
||||||
c: impl Into<DMatrixSliceMut<'a, T>>,
|
beta: T,
|
||||||
alpha: T,
|
c: impl Into<DMatrixSliceMut<'a, T>>,
|
||||||
a: Op<&CsrMatrix<T>>,
|
alpha: T,
|
||||||
b: Op<impl Into<DMatrixSlice<'a, T>>>)
|
a: Op<&CsrMatrix<T>>,
|
||||||
where
|
b: Op<impl Into<DMatrixSlice<'a, T>>>,
|
||||||
T: Scalar + ClosedAdd + ClosedMul + Zero + One
|
) where
|
||||||
|
T: Scalar + ClosedAdd + ClosedMul + Zero + One,
|
||||||
{
|
{
|
||||||
let b = b.convert();
|
let b = b.convert();
|
||||||
spmm_csr_dense_(beta, c.into(), alpha, a, b)
|
spmm_csr_dense_(beta, c.into(), alpha, a, b)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn spmm_csr_dense_<T>(beta: T,
|
fn spmm_csr_dense_<T>(
|
||||||
c: DMatrixSliceMut<T>,
|
beta: T,
|
||||||
alpha: T,
|
c: DMatrixSliceMut<T>,
|
||||||
a: Op<&CsrMatrix<T>>,
|
alpha: T,
|
||||||
b: Op<DMatrixSlice<T>>)
|
a: Op<&CsrMatrix<T>>,
|
||||||
where
|
b: Op<DMatrixSlice<T>>,
|
||||||
T: Scalar + ClosedAdd + ClosedMul + Zero + One
|
) where
|
||||||
|
T: Scalar + ClosedAdd + ClosedMul + Zero + One,
|
||||||
{
|
{
|
||||||
assert_compatible_spmm_dims!(c, a, b);
|
assert_compatible_spmm_dims!(c, a, b);
|
||||||
spmm_cs_dense(beta, c, alpha, a.map_same_op(|a| &a.cs), b)
|
spmm_cs_dense(beta, c, alpha, a.map_same_op(|a| &a.cs), b)
|
||||||
|
@ -41,13 +43,14 @@ where
|
||||||
/// # Panics
|
/// # Panics
|
||||||
///
|
///
|
||||||
/// Panics if the dimensions of the matrices involved are not compatible with the expression.
|
/// Panics if the dimensions of the matrices involved are not compatible with the expression.
|
||||||
pub fn spadd_csr_prealloc<T>(beta: T,
|
pub fn spadd_csr_prealloc<T>(
|
||||||
c: &mut CsrMatrix<T>,
|
beta: T,
|
||||||
alpha: T,
|
c: &mut CsrMatrix<T>,
|
||||||
a: Op<&CsrMatrix<T>>)
|
alpha: T,
|
||||||
-> Result<(), OperationError>
|
a: Op<&CsrMatrix<T>>,
|
||||||
|
) -> Result<(), OperationError>
|
||||||
where
|
where
|
||||||
T: Scalar + ClosedAdd + ClosedMul + Zero + One
|
T: Scalar + ClosedAdd + ClosedMul + Zero + One,
|
||||||
{
|
{
|
||||||
assert_compatible_spadd_dims!(c, a);
|
assert_compatible_spadd_dims!(c, a);
|
||||||
spadd_cs_prealloc(beta, &mut c.cs, alpha, a.map_same_op(|a| &a.cs))
|
spadd_cs_prealloc(beta, &mut c.cs, alpha, a.map_same_op(|a| &a.cs))
|
||||||
|
@ -67,19 +70,17 @@ pub fn spmm_csr_prealloc<T>(
|
||||||
c: &mut CsrMatrix<T>,
|
c: &mut CsrMatrix<T>,
|
||||||
alpha: T,
|
alpha: T,
|
||||||
a: Op<&CsrMatrix<T>>,
|
a: Op<&CsrMatrix<T>>,
|
||||||
b: Op<&CsrMatrix<T>>)
|
b: Op<&CsrMatrix<T>>,
|
||||||
-> Result<(), OperationError>
|
) -> Result<(), OperationError>
|
||||||
where
|
where
|
||||||
T: Scalar + ClosedAdd + ClosedMul + Zero + One
|
T: Scalar + ClosedAdd + ClosedMul + Zero + One,
|
||||||
{
|
{
|
||||||
assert_compatible_spmm_dims!(c, a, b);
|
assert_compatible_spmm_dims!(c, a, b);
|
||||||
|
|
||||||
use Op::{NoOp, Transpose};
|
use Op::{NoOp, Transpose};
|
||||||
|
|
||||||
match (&a, &b) {
|
match (&a, &b) {
|
||||||
(NoOp(ref a), NoOp(ref b)) => {
|
(NoOp(ref a), NoOp(ref b)) => spmm_cs_prealloc(beta, &mut c.cs, alpha, &a.cs, &b.cs),
|
||||||
spmm_cs_prealloc(beta, &mut c.cs, alpha, &a.cs, &b.cs)
|
|
||||||
},
|
|
||||||
_ => {
|
_ => {
|
||||||
// Currently we handle transposition by explicitly precomputing transposed matrices
|
// Currently we handle transposition by explicitly precomputing transposed matrices
|
||||||
// and calling the operation again without transposition
|
// and calling the operation again without transposition
|
||||||
|
@ -93,7 +94,9 @@ where
|
||||||
(NoOp(_), NoOp(_)) => unreachable!(),
|
(NoOp(_), NoOp(_)) => unreachable!(),
|
||||||
(Transpose(ref a), NoOp(_)) => (Owned(a.transpose()), Borrowed(b_ref)),
|
(Transpose(ref a), NoOp(_)) => (Owned(a.transpose()), Borrowed(b_ref)),
|
||||||
(NoOp(_), Transpose(ref b)) => (Borrowed(a_ref), Owned(b.transpose())),
|
(NoOp(_), Transpose(ref b)) => (Borrowed(a_ref), Owned(b.transpose())),
|
||||||
(Transpose(ref a), Transpose(ref b)) => (Owned(a.transpose()), Owned(b.transpose()))
|
(Transpose(ref a), Transpose(ref b)) => {
|
||||||
|
(Owned(a.transpose()), Owned(b.transpose()))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -101,4 +104,3 @@ where
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -10,33 +10,31 @@
|
||||||
|
|
||||||
#[macro_use]
|
#[macro_use]
|
||||||
macro_rules! assert_compatible_spmm_dims {
|
macro_rules! assert_compatible_spmm_dims {
|
||||||
($c:expr, $a:expr, $b:expr) => {
|
($c:expr, $a:expr, $b:expr) => {{
|
||||||
{
|
use crate::ops::Op::{NoOp, Transpose};
|
||||||
use crate::ops::Op::{NoOp, Transpose};
|
match (&$a, &$b) {
|
||||||
match (&$a, &$b) {
|
(NoOp(ref a), NoOp(ref b)) => {
|
||||||
(NoOp(ref a), NoOp(ref b)) => {
|
assert_eq!($c.nrows(), a.nrows(), "C.nrows() != A.nrows()");
|
||||||
assert_eq!($c.nrows(), a.nrows(), "C.nrows() != A.nrows()");
|
assert_eq!($c.ncols(), b.ncols(), "C.ncols() != B.ncols()");
|
||||||
assert_eq!($c.ncols(), b.ncols(), "C.ncols() != B.ncols()");
|
assert_eq!(a.ncols(), b.nrows(), "A.ncols() != B.nrows()");
|
||||||
assert_eq!(a.ncols(), b.nrows(), "A.ncols() != B.nrows()");
|
}
|
||||||
},
|
(Transpose(ref a), NoOp(ref b)) => {
|
||||||
(Transpose(ref a), NoOp(ref b)) => {
|
assert_eq!($c.nrows(), a.ncols(), "C.nrows() != A.ncols()");
|
||||||
assert_eq!($c.nrows(), a.ncols(), "C.nrows() != A.ncols()");
|
assert_eq!($c.ncols(), b.ncols(), "C.ncols() != B.ncols()");
|
||||||
assert_eq!($c.ncols(), b.ncols(), "C.ncols() != B.ncols()");
|
assert_eq!(a.nrows(), b.nrows(), "A.nrows() != B.nrows()");
|
||||||
assert_eq!(a.nrows(), b.nrows(), "A.nrows() != B.nrows()");
|
}
|
||||||
},
|
(NoOp(ref a), Transpose(ref b)) => {
|
||||||
(NoOp(ref a), Transpose(ref b)) => {
|
assert_eq!($c.nrows(), a.nrows(), "C.nrows() != A.nrows()");
|
||||||
assert_eq!($c.nrows(), a.nrows(), "C.nrows() != A.nrows()");
|
assert_eq!($c.ncols(), b.nrows(), "C.ncols() != B.nrows()");
|
||||||
assert_eq!($c.ncols(), b.nrows(), "C.ncols() != B.nrows()");
|
assert_eq!(a.ncols(), b.ncols(), "A.ncols() != B.ncols()");
|
||||||
assert_eq!(a.ncols(), b.ncols(), "A.ncols() != B.ncols()");
|
}
|
||||||
},
|
(Transpose(ref a), Transpose(ref b)) => {
|
||||||
(Transpose(ref a), Transpose(ref b)) => {
|
assert_eq!($c.nrows(), a.ncols(), "C.nrows() != A.ncols()");
|
||||||
assert_eq!($c.nrows(), a.ncols(), "C.nrows() != A.ncols()");
|
assert_eq!($c.ncols(), b.nrows(), "C.ncols() != B.nrows()");
|
||||||
assert_eq!($c.ncols(), b.nrows(), "C.ncols() != B.nrows()");
|
assert_eq!(a.nrows(), b.ncols(), "A.nrows() != B.ncols()");
|
||||||
assert_eq!(a.nrows(), b.ncols(), "A.nrows() != B.ncols()");
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}};
|
||||||
}
|
}
|
||||||
|
|
||||||
#[macro_use]
|
#[macro_use]
|
||||||
|
@ -47,32 +45,31 @@ macro_rules! assert_compatible_spadd_dims {
|
||||||
Op::NoOp(a) => {
|
Op::NoOp(a) => {
|
||||||
assert_eq!($c.nrows(), a.nrows(), "C.nrows() != A.nrows()");
|
assert_eq!($c.nrows(), a.nrows(), "C.nrows() != A.nrows()");
|
||||||
assert_eq!($c.ncols(), a.ncols(), "C.ncols() != A.ncols()");
|
assert_eq!($c.ncols(), a.ncols(), "C.ncols() != A.ncols()");
|
||||||
},
|
}
|
||||||
Op::Transpose(a) => {
|
Op::Transpose(a) => {
|
||||||
assert_eq!($c.nrows(), a.ncols(), "C.nrows() != A.ncols()");
|
assert_eq!($c.nrows(), a.ncols(), "C.nrows() != A.ncols()");
|
||||||
assert_eq!($c.ncols(), a.nrows(), "C.ncols() != A.nrows()");
|
assert_eq!($c.ncols(), a.nrows(), "C.ncols() != A.nrows()");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
};
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
mod cs;
|
||||||
mod csc;
|
mod csc;
|
||||||
mod csr;
|
mod csr;
|
||||||
mod pattern;
|
mod pattern;
|
||||||
mod cs;
|
|
||||||
|
|
||||||
pub use csc::*;
|
pub use csc::*;
|
||||||
pub use csr::*;
|
pub use csr::*;
|
||||||
pub use pattern::*;
|
pub use pattern::*;
|
||||||
use std::fmt::Formatter;
|
|
||||||
use std::fmt;
|
use std::fmt;
|
||||||
|
use std::fmt::Formatter;
|
||||||
|
|
||||||
/// A description of the error that occurred during an arithmetic operation.
|
/// A description of the error that occurred during an arithmetic operation.
|
||||||
#[derive(Clone, Debug)]
|
#[derive(Clone, Debug)]
|
||||||
pub struct OperationError {
|
pub struct OperationError {
|
||||||
error_kind: OperationErrorKind,
|
error_kind: OperationErrorKind,
|
||||||
message: String
|
message: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// The different kinds of operation errors that may occur.
|
/// The different kinds of operation errors that may occur.
|
||||||
|
@ -92,7 +89,10 @@ pub enum OperationErrorKind {
|
||||||
|
|
||||||
impl OperationError {
|
impl OperationError {
|
||||||
fn from_kind_and_message(error_type: OperationErrorKind, message: String) -> Self {
|
fn from_kind_and_message(error_type: OperationErrorKind, message: String) -> Self {
|
||||||
Self { error_kind: error_type, message }
|
Self {
|
||||||
|
error_kind: error_type,
|
||||||
|
message,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// The operation error kind.
|
/// The operation error kind.
|
||||||
|
@ -110,11 +110,15 @@ impl fmt::Display for OperationError {
|
||||||
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
|
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
|
||||||
write!(f, "Sparse matrix operation error: ")?;
|
write!(f, "Sparse matrix operation error: ")?;
|
||||||
match self.kind() {
|
match self.kind() {
|
||||||
OperationErrorKind::InvalidPattern => { write!(f, "InvalidPattern")?; }
|
OperationErrorKind::InvalidPattern => {
|
||||||
OperationErrorKind::Singular => { write!(f, "Singular")?; }
|
write!(f, "InvalidPattern")?;
|
||||||
|
}
|
||||||
|
OperationErrorKind::Singular => {
|
||||||
|
write!(f, "Singular")?;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
write!(f, " Message: {}", self.message)
|
write!(f, " Message: {}", self.message)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl std::error::Error for OperationError {}
|
impl std::error::Error for OperationError {}
|
||||||
|
|
|
@ -12,11 +12,17 @@ use std::iter;
|
||||||
/// # Panics
|
/// # Panics
|
||||||
///
|
///
|
||||||
/// Panics if the patterns do not have the same major and minor dimensions.
|
/// Panics if the patterns do not have the same major and minor dimensions.
|
||||||
pub fn spadd_pattern(a: &SparsityPattern,
|
pub fn spadd_pattern(a: &SparsityPattern, b: &SparsityPattern) -> SparsityPattern {
|
||||||
b: &SparsityPattern) -> SparsityPattern
|
assert_eq!(
|
||||||
{
|
a.major_dim(),
|
||||||
assert_eq!(a.major_dim(), b.major_dim(), "Patterns must have identical major dimensions.");
|
b.major_dim(),
|
||||||
assert_eq!(a.minor_dim(), b.minor_dim(), "Patterns must have identical minor dimensions.");
|
"Patterns must have identical major dimensions."
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
a.minor_dim(),
|
||||||
|
b.minor_dim(),
|
||||||
|
"Patterns must have identical minor dimensions."
|
||||||
|
);
|
||||||
|
|
||||||
let mut offsets = Vec::new();
|
let mut offsets = Vec::new();
|
||||||
let mut indices = Vec::new();
|
let mut indices = Vec::new();
|
||||||
|
@ -25,7 +31,7 @@ pub fn spadd_pattern(a: &SparsityPattern,
|
||||||
|
|
||||||
offsets.push(0);
|
offsets.push(0);
|
||||||
|
|
||||||
for lane_idx in 0 .. a.major_dim() {
|
for lane_idx in 0..a.major_dim() {
|
||||||
let lane_a = a.lane(lane_idx);
|
let lane_a = a.lane(lane_idx);
|
||||||
let lane_b = b.lane(lane_idx);
|
let lane_b = b.lane(lane_idx);
|
||||||
indices.extend(iterate_union(lane_a, lane_b));
|
indices.extend(iterate_union(lane_a, lane_b));
|
||||||
|
@ -33,8 +39,7 @@ pub fn spadd_pattern(a: &SparsityPattern,
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Consider circumventing format checks? (requires unsafe, should benchmark first)
|
// TODO: Consider circumventing format checks? (requires unsafe, should benchmark first)
|
||||||
SparsityPattern::try_from_offsets_and_indices(
|
SparsityPattern::try_from_offsets_and_indices(a.major_dim(), a.minor_dim(), offsets, indices)
|
||||||
a.major_dim(), a.minor_dim(), offsets, indices)
|
|
||||||
.expect("Internal error: Pattern must be valid by definition")
|
.expect("Internal error: Pattern must be valid by definition")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -66,7 +71,11 @@ pub fn spmm_csc_pattern(a: &SparsityPattern, b: &SparsityPattern) -> SparsityPat
|
||||||
/// Panics if the patterns, when interpreted as CSR patterns, are not compatible for
|
/// Panics if the patterns, when interpreted as CSR patterns, are not compatible for
|
||||||
/// matrix multiplication.
|
/// matrix multiplication.
|
||||||
pub fn spmm_csr_pattern(a: &SparsityPattern, b: &SparsityPattern) -> SparsityPattern {
|
pub fn spmm_csr_pattern(a: &SparsityPattern, b: &SparsityPattern) -> SparsityPattern {
|
||||||
assert_eq!(a.minor_dim(), b.major_dim(), "a and b must have compatible dimensions");
|
assert_eq!(
|
||||||
|
a.minor_dim(),
|
||||||
|
b.major_dim(),
|
||||||
|
"a and b must have compatible dimensions"
|
||||||
|
);
|
||||||
|
|
||||||
let mut offsets = Vec::new();
|
let mut offsets = Vec::new();
|
||||||
let mut indices = Vec::new();
|
let mut indices = Vec::new();
|
||||||
|
@ -78,7 +87,7 @@ pub fn spmm_csr_pattern(a: &SparsityPattern, b: &SparsityPattern) -> SparsityPat
|
||||||
// (would cut memory use to 1/8, which might help reduce cache misses)
|
// (would cut memory use to 1/8, which might help reduce cache misses)
|
||||||
let mut visited = vec![false; b.minor_dim()];
|
let mut visited = vec![false; b.minor_dim()];
|
||||||
|
|
||||||
for i in 0 .. a.major_dim() {
|
for i in 0..a.major_dim() {
|
||||||
let a_lane_i = a.lane(i);
|
let a_lane_i = a.lane(i);
|
||||||
let c_lane_i_offset = *offsets.last().unwrap();
|
let c_lane_i_offset = *offsets.last().unwrap();
|
||||||
for &k in a_lane_i {
|
for &k in a_lane_i {
|
||||||
|
@ -93,7 +102,7 @@ pub fn spmm_csr_pattern(a: &SparsityPattern, b: &SparsityPattern) -> SparsityPat
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let c_lane_i = &mut indices[c_lane_i_offset ..];
|
let c_lane_i = &mut indices[c_lane_i_offset..];
|
||||||
c_lane_i.sort_unstable();
|
c_lane_i.sort_unstable();
|
||||||
|
|
||||||
// Reset visits so that visited[j] == false for all j for the next major lane
|
// Reset visits so that visited[j] == false for all j for the next major lane
|
||||||
|
@ -110,21 +119,23 @@ pub fn spmm_csr_pattern(a: &SparsityPattern, b: &SparsityPattern) -> SparsityPat
|
||||||
|
|
||||||
/// Iterate over the union of the two sets represented by sorted slices
|
/// Iterate over the union of the two sets represented by sorted slices
|
||||||
/// (with unique elements)
|
/// (with unique elements)
|
||||||
fn iterate_union<'a>(mut sorted_a: &'a [usize],
|
fn iterate_union<'a>(
|
||||||
mut sorted_b: &'a [usize]) -> impl Iterator<Item=usize> + 'a {
|
mut sorted_a: &'a [usize],
|
||||||
|
mut sorted_b: &'a [usize],
|
||||||
|
) -> impl Iterator<Item = usize> + 'a {
|
||||||
iter::from_fn(move || {
|
iter::from_fn(move || {
|
||||||
if let (Some(a_item), Some(b_item)) = (sorted_a.first(), sorted_b.first()) {
|
if let (Some(a_item), Some(b_item)) = (sorted_a.first(), sorted_b.first()) {
|
||||||
let item = if a_item < b_item {
|
let item = if a_item < b_item {
|
||||||
sorted_a = &sorted_a[1 ..];
|
sorted_a = &sorted_a[1..];
|
||||||
a_item
|
a_item
|
||||||
} else if b_item < a_item {
|
} else if b_item < a_item {
|
||||||
sorted_b = &sorted_b[1 ..];
|
sorted_b = &sorted_b[1..];
|
||||||
b_item
|
b_item
|
||||||
} else {
|
} else {
|
||||||
// Both lists contain the same element, advance both slices to avoid
|
// Both lists contain the same element, advance both slices to avoid
|
||||||
// duplicate entries in the result
|
// duplicate entries in the result
|
||||||
sorted_a = &sorted_a[1 ..];
|
sorted_a = &sorted_a[1..];
|
||||||
sorted_b = &sorted_b[1 ..];
|
sorted_b = &sorted_b[1..];
|
||||||
a_item
|
a_item
|
||||||
};
|
};
|
||||||
Some(*item)
|
Some(*item)
|
||||||
|
@ -138,4 +149,4 @@ fn iterate_union<'a>(mut sorted_a: &'a [usize],
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,8 +1,8 @@
|
||||||
//! Sparsity patterns for CSR and CSC matrices.
|
//! Sparsity patterns for CSR and CSC matrices.
|
||||||
use crate::SparseFormatError;
|
|
||||||
use std::fmt;
|
|
||||||
use std::error::Error;
|
|
||||||
use crate::cs::transpose_cs;
|
use crate::cs::transpose_cs;
|
||||||
|
use crate::SparseFormatError;
|
||||||
|
use std::error::Error;
|
||||||
|
use std::fmt;
|
||||||
|
|
||||||
/// A representation of the sparsity pattern of a CSR or CSC matrix.
|
/// A representation of the sparsity pattern of a CSR or CSC matrix.
|
||||||
///
|
///
|
||||||
|
@ -137,7 +137,7 @@ impl SparsityPattern {
|
||||||
// minor indices within a lane are sorted, unique. In addition, each minor index
|
// minor indices within a lane are sorted, unique. In addition, each minor index
|
||||||
// must be in bounds with respect to the minor dimension.
|
// must be in bounds with respect to the minor dimension.
|
||||||
{
|
{
|
||||||
for lane_idx in 0 .. major_dim {
|
for lane_idx in 0..major_dim {
|
||||||
let range_start = major_offsets[lane_idx];
|
let range_start = major_offsets[lane_idx];
|
||||||
let range_end = major_offsets[lane_idx + 1];
|
let range_end = major_offsets[lane_idx + 1];
|
||||||
|
|
||||||
|
@ -146,7 +146,7 @@ impl SparsityPattern {
|
||||||
return Err(NonmonotonicOffsets);
|
return Err(NonmonotonicOffsets);
|
||||||
}
|
}
|
||||||
|
|
||||||
let minor_indices = &minor_indices[range_start .. range_end];
|
let minor_indices = &minor_indices[range_start..range_end];
|
||||||
|
|
||||||
// We test for in-bounds, uniqueness and monotonicity at the same time
|
// We test for in-bounds, uniqueness and monotonicity at the same time
|
||||||
// to ensure that we only visit each minor index once
|
// to ensure that we only visit each minor index once
|
||||||
|
@ -232,17 +232,20 @@ impl SparsityPattern {
|
||||||
// By using unit () values, we can use the same routines as for CSR/CSC matrices
|
// By using unit () values, we can use the same routines as for CSR/CSC matrices
|
||||||
let values = vec![(); self.nnz()];
|
let values = vec![(); self.nnz()];
|
||||||
let (new_offsets, new_indices, _) = transpose_cs(
|
let (new_offsets, new_indices, _) = transpose_cs(
|
||||||
self.major_dim(),
|
self.major_dim(),
|
||||||
self.minor_dim(),
|
self.minor_dim(),
|
||||||
self.major_offsets(),
|
self.major_offsets(),
|
||||||
self.minor_indices(),
|
self.minor_indices(),
|
||||||
&values);
|
&values,
|
||||||
|
);
|
||||||
// TODO: Skip checks
|
// TODO: Skip checks
|
||||||
Self::try_from_offsets_and_indices(self.minor_dim(),
|
Self::try_from_offsets_and_indices(
|
||||||
self.major_dim(),
|
self.minor_dim(),
|
||||||
new_offsets,
|
self.major_dim(),
|
||||||
new_indices)
|
new_offsets,
|
||||||
.expect("Internal error: Transpose should never fail.")
|
new_indices,
|
||||||
|
)
|
||||||
|
.expect("Internal error: Transpose should never fail.")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -275,22 +278,25 @@ pub enum SparsityPatternFormatError {
|
||||||
|
|
||||||
impl From<SparsityPatternFormatError> for SparseFormatError {
|
impl From<SparsityPatternFormatError> for SparseFormatError {
|
||||||
fn from(err: SparsityPatternFormatError) -> Self {
|
fn from(err: SparsityPatternFormatError) -> Self {
|
||||||
use SparsityPatternFormatError::*;
|
|
||||||
use SparsityPatternFormatError::DuplicateEntry as PatternDuplicateEntry;
|
|
||||||
use crate::SparseFormatErrorKind;
|
use crate::SparseFormatErrorKind;
|
||||||
use crate::SparseFormatErrorKind::*;
|
use crate::SparseFormatErrorKind::*;
|
||||||
|
use SparsityPatternFormatError::DuplicateEntry as PatternDuplicateEntry;
|
||||||
|
use SparsityPatternFormatError::*;
|
||||||
match err {
|
match err {
|
||||||
InvalidOffsetArrayLength
|
InvalidOffsetArrayLength
|
||||||
| InvalidOffsetFirstLast
|
| InvalidOffsetFirstLast
|
||||||
| NonmonotonicOffsets
|
| NonmonotonicOffsets
|
||||||
| NonmonotonicMinorIndices
|
| NonmonotonicMinorIndices => {
|
||||||
=> SparseFormatError::from_kind_and_error(InvalidStructure, Box::from(err)),
|
SparseFormatError::from_kind_and_error(InvalidStructure, Box::from(err))
|
||||||
MinorIndexOutOfBounds
|
}
|
||||||
=> SparseFormatError::from_kind_and_error(IndexOutOfBounds,
|
MinorIndexOutOfBounds => {
|
||||||
Box::from(err)),
|
SparseFormatError::from_kind_and_error(IndexOutOfBounds, Box::from(err))
|
||||||
PatternDuplicateEntry
|
}
|
||||||
=> SparseFormatError::from_kind_and_error(SparseFormatErrorKind::DuplicateEntry,
|
PatternDuplicateEntry => SparseFormatError::from_kind_and_error(
|
||||||
Box::from(err)),
|
#[allow(unused_qualifications)]
|
||||||
|
SparseFormatErrorKind::DuplicateEntry,
|
||||||
|
Box::from(err),
|
||||||
|
),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -300,22 +306,25 @@ impl fmt::Display for SparsityPatternFormatError {
|
||||||
match self {
|
match self {
|
||||||
SparsityPatternFormatError::InvalidOffsetArrayLength => {
|
SparsityPatternFormatError::InvalidOffsetArrayLength => {
|
||||||
write!(f, "Length of offset array is not equal to (major_dim + 1).")
|
write!(f, "Length of offset array is not equal to (major_dim + 1).")
|
||||||
},
|
}
|
||||||
SparsityPatternFormatError::InvalidOffsetFirstLast => {
|
SparsityPatternFormatError::InvalidOffsetFirstLast => {
|
||||||
write!(f, "First or last offset is incompatible with format.")
|
write!(f, "First or last offset is incompatible with format.")
|
||||||
},
|
}
|
||||||
SparsityPatternFormatError::NonmonotonicOffsets => {
|
SparsityPatternFormatError::NonmonotonicOffsets => {
|
||||||
write!(f, "Offsets are not monotonically increasing.")
|
write!(f, "Offsets are not monotonically increasing.")
|
||||||
},
|
}
|
||||||
SparsityPatternFormatError::MinorIndexOutOfBounds => {
|
SparsityPatternFormatError::MinorIndexOutOfBounds => {
|
||||||
write!(f, "A minor index is out of bounds.")
|
write!(f, "A minor index is out of bounds.")
|
||||||
},
|
}
|
||||||
SparsityPatternFormatError::DuplicateEntry => {
|
SparsityPatternFormatError::DuplicateEntry => {
|
||||||
write!(f, "Input data contains duplicate entries.")
|
write!(f, "Input data contains duplicate entries.")
|
||||||
},
|
}
|
||||||
SparsityPatternFormatError::NonmonotonicMinorIndices => {
|
SparsityPatternFormatError::NonmonotonicMinorIndices => {
|
||||||
write!(f, "Minor indices are not monotonically increasing within each lane.")
|
write!(
|
||||||
},
|
f,
|
||||||
|
"Minor indices are not monotonically increasing within each lane."
|
||||||
|
)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -335,12 +344,12 @@ pub struct SparsityPatternIter<'a> {
|
||||||
impl<'a> SparsityPatternIter<'a> {
|
impl<'a> SparsityPatternIter<'a> {
|
||||||
fn from_pattern(pattern: &'a SparsityPattern) -> Self {
|
fn from_pattern(pattern: &'a SparsityPattern) -> Self {
|
||||||
let first_lane_end = pattern.major_offsets().get(1).unwrap_or(&0);
|
let first_lane_end = pattern.major_offsets().get(1).unwrap_or(&0);
|
||||||
let minors_in_first_lane = &pattern.minor_indices()[0 .. *first_lane_end];
|
let minors_in_first_lane = &pattern.minor_indices()[0..*first_lane_end];
|
||||||
Self {
|
Self {
|
||||||
major_offsets: pattern.major_offsets(),
|
major_offsets: pattern.major_offsets(),
|
||||||
minor_indices: pattern.minor_indices(),
|
minor_indices: pattern.minor_indices(),
|
||||||
current_lane_idx: 0,
|
current_lane_idx: 0,
|
||||||
remaining_minors_in_lane: minors_in_first_lane
|
remaining_minors_in_lane: minors_in_first_lane,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -374,12 +383,11 @@ impl<'a> Iterator for SparsityPatternIter<'a> {
|
||||||
let lower = self.major_offsets[self.current_lane_idx];
|
let lower = self.major_offsets[self.current_lane_idx];
|
||||||
let upper = self.major_offsets[self.current_lane_idx + 1];
|
let upper = self.major_offsets[self.current_lane_idx + 1];
|
||||||
if upper > lower {
|
if upper > lower {
|
||||||
self.remaining_minors_in_lane = &self.minor_indices[(lower + 1) .. upper];
|
self.remaining_minors_in_lane = &self.minor_indices[(lower + 1)..upper];
|
||||||
return Some((self.current_lane_idx, self.minor_indices[lower]))
|
return Some((self.current_lane_idx, self.minor_indices[lower]));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -11,20 +11,22 @@
|
||||||
mod proptest_patched;
|
mod proptest_patched;
|
||||||
|
|
||||||
use crate::coo::CooMatrix;
|
use crate::coo::CooMatrix;
|
||||||
use proptest::prelude::*;
|
use crate::csc::CscMatrix;
|
||||||
use proptest::collection::{vec, hash_map, btree_set};
|
|
||||||
use nalgebra::{Scalar, Dim};
|
|
||||||
use std::cmp::min;
|
|
||||||
use std::iter::{repeat};
|
|
||||||
use proptest::sample::{Index};
|
|
||||||
use crate::csr::CsrMatrix;
|
use crate::csr::CsrMatrix;
|
||||||
use crate::pattern::SparsityPattern;
|
use crate::pattern::SparsityPattern;
|
||||||
use crate::csc::CscMatrix;
|
|
||||||
use nalgebra::proptest::DimRange;
|
use nalgebra::proptest::DimRange;
|
||||||
|
use nalgebra::{Dim, Scalar};
|
||||||
|
use proptest::collection::{btree_set, hash_map, vec};
|
||||||
|
use proptest::prelude::*;
|
||||||
|
use proptest::sample::Index;
|
||||||
|
use std::cmp::min;
|
||||||
|
use std::iter::repeat;
|
||||||
|
|
||||||
fn dense_row_major_coord_strategy(nrows: usize, ncols: usize, nnz: usize)
|
fn dense_row_major_coord_strategy(
|
||||||
-> impl Strategy<Value=Vec<(usize, usize)>>
|
nrows: usize,
|
||||||
{
|
ncols: usize,
|
||||||
|
nnz: usize,
|
||||||
|
) -> impl Strategy<Value = Vec<(usize, usize)>> {
|
||||||
assert!(nnz <= nrows * ncols);
|
assert!(nnz <= nrows * ncols);
|
||||||
let mut booleans = vec![true; nnz];
|
let mut booleans = vec![true; nnz];
|
||||||
booleans.append(&mut vec![false; (nrows * ncols) - nnz]);
|
booleans.append(&mut vec![false; (nrows * ncols) - nnz]);
|
||||||
|
@ -38,33 +40,33 @@ fn dense_row_major_coord_strategy(nrows: usize, ncols: usize, nnz: usize)
|
||||||
// // Need to shuffle to make sure they are randomly distributed
|
// // Need to shuffle to make sure they are randomly distributed
|
||||||
// .prop_shuffle()
|
// .prop_shuffle()
|
||||||
|
|
||||||
proptest_patched::Shuffle(Just(booleans))
|
proptest_patched::Shuffle(Just(booleans)).prop_map(move |booleans| {
|
||||||
.prop_map(move |booleans| {
|
booleans
|
||||||
booleans
|
.into_iter()
|
||||||
.into_iter()
|
.enumerate()
|
||||||
.enumerate()
|
.filter_map(|(index, is_entry)| {
|
||||||
.filter_map(|(index, is_entry)| {
|
if is_entry {
|
||||||
if is_entry {
|
// Convert linear index to row/col pair
|
||||||
// Convert linear index to row/col pair
|
let i = index / ncols;
|
||||||
let i = index / ncols;
|
let j = index % ncols;
|
||||||
let j = index % ncols;
|
Some((i, j))
|
||||||
Some((i, j))
|
} else {
|
||||||
} else {
|
None
|
||||||
None
|
}
|
||||||
}
|
})
|
||||||
})
|
.collect::<Vec<_>>()
|
||||||
.collect::<Vec<_>>()
|
})
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// A strategy for generating `nnz` triplets.
|
/// A strategy for generating `nnz` triplets.
|
||||||
///
|
///
|
||||||
/// This strategy should generally only be used when `nnz` is close to `nrows * ncols`.
|
/// This strategy should generally only be used when `nnz` is close to `nrows * ncols`.
|
||||||
fn dense_triplet_strategy<T>(value_strategy: T,
|
fn dense_triplet_strategy<T>(
|
||||||
nrows: usize,
|
value_strategy: T,
|
||||||
ncols: usize,
|
nrows: usize,
|
||||||
nnz: usize)
|
ncols: usize,
|
||||||
-> impl Strategy<Value=Vec<(usize, usize, T::Value)>>
|
nnz: usize,
|
||||||
|
) -> impl Strategy<Value = Vec<(usize, usize, T::Value)>>
|
||||||
where
|
where
|
||||||
T: Strategy + Clone + 'static,
|
T: Strategy + Clone + 'static,
|
||||||
T::Value: Scalar,
|
T::Value: Scalar,
|
||||||
|
@ -100,15 +102,14 @@ where
|
||||||
})
|
})
|
||||||
// Assign values to each coordinate pair in order to generate a list of triplets
|
// Assign values to each coordinate pair in order to generate a list of triplets
|
||||||
.prop_flat_map(move |coords| {
|
.prop_flat_map(move |coords| {
|
||||||
vec![value_strategy.clone(); coords.len()]
|
vec![value_strategy.clone(); coords.len()].prop_map(move |values| {
|
||||||
.prop_map(move |values| {
|
coords
|
||||||
coords.clone().into_iter()
|
.clone()
|
||||||
.zip(values)
|
.into_iter()
|
||||||
.map(|((i, j), v)| {
|
.zip(values)
|
||||||
(i, j, v)
|
.map(|((i, j), v)| (i, j, v))
|
||||||
})
|
.collect::<Vec<_>>()
|
||||||
.collect::<Vec<_>>()
|
})
|
||||||
})
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -116,25 +117,23 @@ where
|
||||||
///
|
///
|
||||||
/// This strategy should generally only be used when `nnz << nrows * ncols`. If `nnz` is too
|
/// This strategy should generally only be used when `nnz << nrows * ncols`. If `nnz` is too
|
||||||
/// close to `nrows * ncols` it may fail due to excessive rejected samples.
|
/// close to `nrows * ncols` it may fail due to excessive rejected samples.
|
||||||
fn sparse_triplet_strategy<T>(value_strategy: T,
|
fn sparse_triplet_strategy<T>(
|
||||||
nrows: usize,
|
value_strategy: T,
|
||||||
ncols: usize,
|
nrows: usize,
|
||||||
nnz: usize)
|
ncols: usize,
|
||||||
-> impl Strategy<Value=Vec<(usize, usize, T::Value)>>
|
nnz: usize,
|
||||||
where
|
) -> impl Strategy<Value = Vec<(usize, usize, T::Value)>>
|
||||||
T: Strategy + Clone + 'static,
|
where
|
||||||
T::Value: Scalar,
|
T: Strategy + Clone + 'static,
|
||||||
|
T::Value: Scalar,
|
||||||
{
|
{
|
||||||
// Have to handle the zero case: proptest doesn't like empty ranges (i.e. 0 .. 0)
|
// Have to handle the zero case: proptest doesn't like empty ranges (i.e. 0 .. 0)
|
||||||
let row_index_strategy = if nrows > 0 { 0 .. nrows } else { 0 .. 1 };
|
let row_index_strategy = if nrows > 0 { 0..nrows } else { 0..1 };
|
||||||
let col_index_strategy = if ncols > 0 { 0 .. ncols } else { 0 .. 1 };
|
let col_index_strategy = if ncols > 0 { 0..ncols } else { 0..1 };
|
||||||
let coord_strategy = (row_index_strategy, col_index_strategy);
|
let coord_strategy = (row_index_strategy, col_index_strategy);
|
||||||
hash_map(coord_strategy, value_strategy.clone(), nnz)
|
hash_map(coord_strategy, value_strategy.clone(), nnz)
|
||||||
.prop_map(|hash_map| {
|
.prop_map(|hash_map| {
|
||||||
let triplets: Vec<_> = hash_map
|
let triplets: Vec<_> = hash_map.into_iter().map(|((i, j), v)| (i, j, v)).collect();
|
||||||
.into_iter()
|
|
||||||
.map(|((i, j), v)| (i, j, v))
|
|
||||||
.collect();
|
|
||||||
triplets
|
triplets
|
||||||
})
|
})
|
||||||
// Although order in the hash map is unspecified, it's not necessarily *random*
|
// Although order in the hash map is unspecified, it's not necessarily *random*
|
||||||
|
@ -153,36 +152,41 @@ pub fn coo_no_duplicates<T>(
|
||||||
value_strategy: T,
|
value_strategy: T,
|
||||||
rows: impl Into<DimRange>,
|
rows: impl Into<DimRange>,
|
||||||
cols: impl Into<DimRange>,
|
cols: impl Into<DimRange>,
|
||||||
max_nonzeros: usize) -> impl Strategy<Value=CooMatrix<T::Value>>
|
max_nonzeros: usize,
|
||||||
|
) -> impl Strategy<Value = CooMatrix<T::Value>>
|
||||||
where
|
where
|
||||||
T: Strategy + Clone + 'static,
|
T: Strategy + Clone + 'static,
|
||||||
T::Value: Scalar,
|
T::Value: Scalar,
|
||||||
{
|
{
|
||||||
(rows.into().to_range_inclusive(), cols.into().to_range_inclusive())
|
(
|
||||||
|
rows.into().to_range_inclusive(),
|
||||||
|
cols.into().to_range_inclusive(),
|
||||||
|
)
|
||||||
.prop_flat_map(move |(nrows, ncols)| {
|
.prop_flat_map(move |(nrows, ncols)| {
|
||||||
let max_nonzeros = min(max_nonzeros, nrows * ncols);
|
let max_nonzeros = min(max_nonzeros, nrows * ncols);
|
||||||
let size_range = 0 ..= max_nonzeros;
|
let size_range = 0..=max_nonzeros;
|
||||||
let value_strategy = value_strategy.clone();
|
let value_strategy = value_strategy.clone();
|
||||||
|
|
||||||
size_range.prop_flat_map(move |nnz| {
|
size_range
|
||||||
let value_strategy = value_strategy.clone();
|
.prop_flat_map(move |nnz| {
|
||||||
if nnz as f64 > 0.10 * (nrows as f64) * (ncols as f64) {
|
let value_strategy = value_strategy.clone();
|
||||||
// If the number of nnz is sufficiently dense, then use the dense
|
if nnz as f64 > 0.10 * (nrows as f64) * (ncols as f64) {
|
||||||
// sample strategy
|
// If the number of nnz is sufficiently dense, then use the dense
|
||||||
dense_triplet_strategy(value_strategy, nrows, ncols, nnz).boxed()
|
// sample strategy
|
||||||
} else {
|
dense_triplet_strategy(value_strategy, nrows, ncols, nnz).boxed()
|
||||||
// Otherwise, use a hash map strategy so that we can get a sparse sampling
|
} else {
|
||||||
// (so that complexity is rather on the order of max_nnz than nrows * ncols)
|
// Otherwise, use a hash map strategy so that we can get a sparse sampling
|
||||||
sparse_triplet_strategy(value_strategy, nrows, ncols, nnz).boxed()
|
// (so that complexity is rather on the order of max_nnz than nrows * ncols)
|
||||||
}
|
sparse_triplet_strategy(value_strategy, nrows, ncols, nnz).boxed()
|
||||||
})
|
}
|
||||||
.prop_map(move |triplets| {
|
})
|
||||||
let mut coo = CooMatrix::new(nrows, ncols);
|
.prop_map(move |triplets| {
|
||||||
for (i, j, v) in triplets {
|
let mut coo = CooMatrix::new(nrows, ncols);
|
||||||
coo.push(i, j, v);
|
for (i, j, v) in triplets {
|
||||||
}
|
coo.push(i, j, v);
|
||||||
coo
|
}
|
||||||
})
|
coo
|
||||||
|
})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -198,21 +202,22 @@ where
|
||||||
/// number of duplicate entries is determined by `max_duplicates`. Note that the matrix might still
|
/// number of duplicate entries is determined by `max_duplicates`. Note that the matrix might still
|
||||||
/// contain explicitly stored zeroes if the value strategy is capable of generating zero values.
|
/// contain explicitly stored zeroes if the value strategy is capable of generating zero values.
|
||||||
pub fn coo_with_duplicates<T>(
|
pub fn coo_with_duplicates<T>(
|
||||||
value_strategy: T,
|
value_strategy: T,
|
||||||
rows: impl Into<DimRange>,
|
rows: impl Into<DimRange>,
|
||||||
cols: impl Into<DimRange>,
|
cols: impl Into<DimRange>,
|
||||||
max_nonzeros: usize,
|
max_nonzeros: usize,
|
||||||
max_duplicates: usize)
|
max_duplicates: usize,
|
||||||
-> impl Strategy<Value=CooMatrix<T::Value>>
|
) -> impl Strategy<Value = CooMatrix<T::Value>>
|
||||||
where
|
where
|
||||||
T: Strategy + Clone + 'static,
|
T: Strategy + Clone + 'static,
|
||||||
T::Value: Scalar,
|
T::Value: Scalar,
|
||||||
{
|
{
|
||||||
let coo_strategy = coo_no_duplicates(value_strategy.clone(), rows, cols, max_nonzeros);
|
let coo_strategy = coo_no_duplicates(value_strategy.clone(), rows, cols, max_nonzeros);
|
||||||
let duplicate_strategy = vec((any::<Index>(), value_strategy.clone()), 0 ..= max_duplicates);
|
let duplicate_strategy = vec((any::<Index>(), value_strategy.clone()), 0..=max_duplicates);
|
||||||
(coo_strategy, duplicate_strategy)
|
(coo_strategy, duplicate_strategy)
|
||||||
.prop_flat_map(|(coo, duplicates)| {
|
.prop_flat_map(|(coo, duplicates)| {
|
||||||
let mut triplets: Vec<(usize, usize, T::Value)> = coo.triplet_iter()
|
let mut triplets: Vec<(usize, usize, T::Value)> = coo
|
||||||
|
.triplet_iter()
|
||||||
.map(|(i, j, v)| (i, j, v.clone()))
|
.map(|(i, j, v)| (i, j, v.clone()))
|
||||||
.collect();
|
.collect();
|
||||||
if !triplets.is_empty() {
|
if !triplets.is_empty() {
|
||||||
|
@ -238,9 +243,13 @@ where
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn sparsity_pattern_from_row_major_coords<I>(nmajor: usize, nminor: usize, coords: I) -> SparsityPattern
|
fn sparsity_pattern_from_row_major_coords<I>(
|
||||||
|
nmajor: usize,
|
||||||
|
nminor: usize,
|
||||||
|
coords: I,
|
||||||
|
) -> SparsityPattern
|
||||||
where
|
where
|
||||||
I: Iterator<Item=(usize, usize)> + ExactSizeIterator,
|
I: Iterator<Item = (usize, usize)> + ExactSizeIterator,
|
||||||
{
|
{
|
||||||
let mut minors = Vec::with_capacity(coords.len());
|
let mut minors = Vec::with_capacity(coords.len());
|
||||||
let mut offsets = Vec::with_capacity(nmajor + 1);
|
let mut offsets = Vec::with_capacity(nmajor + 1);
|
||||||
|
@ -248,8 +257,11 @@ where
|
||||||
offsets.push(0);
|
offsets.push(0);
|
||||||
for (idx, (i, j)) in coords.enumerate() {
|
for (idx, (i, j)) in coords.enumerate() {
|
||||||
assert!(i >= current_major);
|
assert!(i >= current_major);
|
||||||
assert!(i < nmajor && j < nminor, "Generated coords are out of bounds");
|
assert!(
|
||||||
while current_major < i{
|
i < nmajor && j < nminor,
|
||||||
|
"Generated coords are out of bounds"
|
||||||
|
);
|
||||||
|
while current_major < i {
|
||||||
offsets.push(idx);
|
offsets.push(idx);
|
||||||
current_major += 1;
|
current_major += 1;
|
||||||
}
|
}
|
||||||
|
@ -264,10 +276,7 @@ where
|
||||||
assert_eq!(offsets.first().unwrap(), &0);
|
assert_eq!(offsets.first().unwrap(), &0);
|
||||||
assert_eq!(offsets.len(), nmajor + 1);
|
assert_eq!(offsets.len(), nmajor + 1);
|
||||||
|
|
||||||
SparsityPattern::try_from_offsets_and_indices(nmajor,
|
SparsityPattern::try_from_offsets_and_indices(nmajor, nminor, offsets, minors)
|
||||||
nminor,
|
|
||||||
offsets,
|
|
||||||
minors)
|
|
||||||
.expect("Internal error: Generated sparsity pattern is invalid")
|
.expect("Internal error: Generated sparsity pattern is invalid")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -275,14 +284,17 @@ where
|
||||||
pub fn sparsity_pattern(
|
pub fn sparsity_pattern(
|
||||||
major_lanes: impl Into<DimRange>,
|
major_lanes: impl Into<DimRange>,
|
||||||
minor_lanes: impl Into<DimRange>,
|
minor_lanes: impl Into<DimRange>,
|
||||||
max_nonzeros: usize)
|
max_nonzeros: usize,
|
||||||
-> impl Strategy<Value=SparsityPattern>
|
) -> impl Strategy<Value = SparsityPattern> {
|
||||||
{
|
(
|
||||||
(major_lanes.into().to_range_inclusive(), minor_lanes.into().to_range_inclusive())
|
major_lanes.into().to_range_inclusive(),
|
||||||
|
minor_lanes.into().to_range_inclusive(),
|
||||||
|
)
|
||||||
.prop_flat_map(move |(nmajor, nminor)| {
|
.prop_flat_map(move |(nmajor, nminor)| {
|
||||||
let max_nonzeros = min(nmajor * nminor, max_nonzeros);
|
let max_nonzeros = min(nmajor * nminor, max_nonzeros);
|
||||||
(Just(nmajor), Just(nminor), 0 ..= max_nonzeros)
|
(Just(nmajor), Just(nminor), 0..=max_nonzeros)
|
||||||
}).prop_flat_map(move |(nmajor, nminor, nnz)| {
|
})
|
||||||
|
.prop_flat_map(move |(nmajor, nminor, nnz)| {
|
||||||
if 10 * nnz < nmajor * nminor {
|
if 10 * nnz < nmajor * nminor {
|
||||||
// If nnz is small compared to a dense matrix, then use a sparse sampling strategy
|
// If nnz is small compared to a dense matrix, then use a sparse sampling strategy
|
||||||
btree_set((0..nmajor, 0..nminor), nnz)
|
btree_set((0..nmajor, 0..nminor), nnz)
|
||||||
|
@ -297,55 +309,66 @@ pub fn sparsity_pattern(
|
||||||
.prop_map(move |coords| {
|
.prop_map(move |coords| {
|
||||||
let coords = coords.into_iter();
|
let coords = coords.into_iter();
|
||||||
sparsity_pattern_from_row_major_coords(nmajor, nminor, coords)
|
sparsity_pattern_from_row_major_coords(nmajor, nminor, coords)
|
||||||
}).boxed()
|
})
|
||||||
|
.boxed()
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// A strategy for generating CSR matrices.
|
/// A strategy for generating CSR matrices.
|
||||||
pub fn csr<T>(value_strategy: T,
|
pub fn csr<T>(
|
||||||
rows: impl Into<DimRange>,
|
value_strategy: T,
|
||||||
cols: impl Into<DimRange>,
|
rows: impl Into<DimRange>,
|
||||||
max_nonzeros: usize)
|
cols: impl Into<DimRange>,
|
||||||
-> impl Strategy<Value=CsrMatrix<T::Value>>
|
max_nonzeros: usize,
|
||||||
|
) -> impl Strategy<Value = CsrMatrix<T::Value>>
|
||||||
where
|
where
|
||||||
T: Strategy + Clone + 'static,
|
T: Strategy + Clone + 'static,
|
||||||
T::Value: Scalar,
|
T::Value: Scalar,
|
||||||
{
|
{
|
||||||
let rows = rows.into();
|
let rows = rows.into();
|
||||||
let cols = cols.into();
|
let cols = cols.into();
|
||||||
sparsity_pattern(rows.lower_bound().value() ..= rows.upper_bound().value(), cols.lower_bound().value() ..= cols.upper_bound().value(), max_nonzeros)
|
sparsity_pattern(
|
||||||
.prop_flat_map(move |pattern| {
|
rows.lower_bound().value()..=rows.upper_bound().value(),
|
||||||
let nnz = pattern.nnz();
|
cols.lower_bound().value()..=cols.upper_bound().value(),
|
||||||
let values = vec![value_strategy.clone(); nnz];
|
max_nonzeros,
|
||||||
(Just(pattern), values)
|
)
|
||||||
})
|
.prop_flat_map(move |pattern| {
|
||||||
.prop_map(|(pattern, values)| {
|
let nnz = pattern.nnz();
|
||||||
CsrMatrix::try_from_pattern_and_values(pattern, values)
|
let values = vec![value_strategy.clone(); nnz];
|
||||||
.expect("Internal error: Generated CsrMatrix is invalid")
|
(Just(pattern), values)
|
||||||
})
|
})
|
||||||
|
.prop_map(|(pattern, values)| {
|
||||||
|
CsrMatrix::try_from_pattern_and_values(pattern, values)
|
||||||
|
.expect("Internal error: Generated CsrMatrix is invalid")
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// A strategy for generating CSC matrices.
|
/// A strategy for generating CSC matrices.
|
||||||
pub fn csc<T>(value_strategy: T,
|
pub fn csc<T>(
|
||||||
rows: impl Into<DimRange>,
|
value_strategy: T,
|
||||||
cols: impl Into<DimRange>,
|
rows: impl Into<DimRange>,
|
||||||
max_nonzeros: usize)
|
cols: impl Into<DimRange>,
|
||||||
-> impl Strategy<Value=CscMatrix<T::Value>>
|
max_nonzeros: usize,
|
||||||
where
|
) -> impl Strategy<Value = CscMatrix<T::Value>>
|
||||||
T: Strategy + Clone + 'static,
|
where
|
||||||
T::Value: Scalar,
|
T: Strategy + Clone + 'static,
|
||||||
|
T::Value: Scalar,
|
||||||
{
|
{
|
||||||
let rows = rows.into();
|
let rows = rows.into();
|
||||||
let cols = cols.into();
|
let cols = cols.into();
|
||||||
sparsity_pattern(cols.lower_bound().value() ..= cols.upper_bound().value(), rows.lower_bound().value() ..= rows.upper_bound().value(), max_nonzeros)
|
sparsity_pattern(
|
||||||
.prop_flat_map(move |pattern| {
|
cols.lower_bound().value()..=cols.upper_bound().value(),
|
||||||
let nnz = pattern.nnz();
|
rows.lower_bound().value()..=rows.upper_bound().value(),
|
||||||
let values = vec![value_strategy.clone(); nnz];
|
max_nonzeros,
|
||||||
(Just(pattern), values)
|
)
|
||||||
})
|
.prop_flat_map(move |pattern| {
|
||||||
.prop_map(|(pattern, values)| {
|
let nnz = pattern.nnz();
|
||||||
CscMatrix::try_from_pattern_and_values(pattern, values)
|
let values = vec![value_strategy.clone(); nnz];
|
||||||
.expect("Internal error: Generated CscMatrix is invalid")
|
(Just(pattern), values)
|
||||||
})
|
})
|
||||||
}
|
.prop_map(|(pattern, values)| {
|
||||||
|
CscMatrix::try_from_pattern_and_values(pattern, values)
|
||||||
|
.expect("Internal error: Generated CscMatrix is invalid")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
|
@ -22,19 +22,19 @@
|
||||||
|
|
||||||
*/
|
*/
|
||||||
|
|
||||||
use proptest::strategy::{Strategy, Shuffleable, NewTree, ValueTree};
|
|
||||||
use proptest::test_runner::{TestRunner, TestRng};
|
|
||||||
use std::cell::Cell;
|
|
||||||
use proptest::num;
|
use proptest::num;
|
||||||
use proptest::prelude::Rng;
|
use proptest::prelude::Rng;
|
||||||
|
use proptest::strategy::{NewTree, Shuffleable, Strategy, ValueTree};
|
||||||
|
use proptest::test_runner::{TestRng, TestRunner};
|
||||||
|
use std::cell::Cell;
|
||||||
|
|
||||||
#[derive(Clone, Debug)]
|
#[derive(Clone, Debug)]
|
||||||
#[must_use = "strategies do nothing unless used"]
|
#[must_use = "strategies do nothing unless used"]
|
||||||
pub struct Shuffle<S>(pub(super) S);
|
pub struct Shuffle<S>(pub(super) S);
|
||||||
|
|
||||||
impl<S: Strategy> Strategy for Shuffle<S>
|
impl<S: Strategy> Strategy for Shuffle<S>
|
||||||
where
|
where
|
||||||
S::Value: Shuffleable,
|
S::Value: Shuffleable,
|
||||||
{
|
{
|
||||||
type Tree = ShuffleValueTree<S::Tree>;
|
type Tree = ShuffleValueTree<S::Tree>;
|
||||||
type Value = S::Value;
|
type Value = S::Value;
|
||||||
|
@ -60,8 +60,8 @@ pub struct ShuffleValueTree<V> {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<V: ValueTree> ShuffleValueTree<V>
|
impl<V: ValueTree> ShuffleValueTree<V>
|
||||||
where
|
where
|
||||||
V::Value: Shuffleable,
|
V::Value: Shuffleable,
|
||||||
{
|
{
|
||||||
fn init_dist(&self, dflt: usize) -> usize {
|
fn init_dist(&self, dflt: usize) -> usize {
|
||||||
if self.dist.get().is_none() {
|
if self.dist.get().is_none() {
|
||||||
|
@ -79,8 +79,8 @@ impl<V: ValueTree> ShuffleValueTree<V>
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<V: ValueTree> ValueTree for ShuffleValueTree<V>
|
impl<V: ValueTree> ValueTree for ShuffleValueTree<V>
|
||||||
where
|
where
|
||||||
V::Value: Shuffleable,
|
V::Value: Shuffleable,
|
||||||
{
|
{
|
||||||
type Value = V::Value;
|
type Value = V::Value;
|
||||||
|
|
||||||
|
@ -143,4 +143,4 @@ impl<V: ValueTree> ValueTree for ShuffleValueTree<V>
|
||||||
self.dist.get_mut().as_mut().unwrap().complicate()
|
self.dist.get_mut().as_mut().unwrap().complicate()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,15 +1,15 @@
|
||||||
use proptest::strategy::Strategy;
|
|
||||||
use nalgebra_sparse::csr::CsrMatrix;
|
|
||||||
use nalgebra_sparse::proptest::{csr, csc};
|
|
||||||
use nalgebra_sparse::csc::CscMatrix;
|
use nalgebra_sparse::csc::CscMatrix;
|
||||||
use std::ops::RangeInclusive;
|
use nalgebra_sparse::csr::CsrMatrix;
|
||||||
use std::convert::{TryFrom};
|
use nalgebra_sparse::proptest::{csc, csr};
|
||||||
|
use proptest::strategy::Strategy;
|
||||||
|
use std::convert::TryFrom;
|
||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
|
use std::ops::RangeInclusive;
|
||||||
|
|
||||||
#[macro_export]
|
#[macro_export]
|
||||||
macro_rules! assert_panics {
|
macro_rules! assert_panics {
|
||||||
($e:expr) => {{
|
($e:expr) => {{
|
||||||
use std::panic::{catch_unwind};
|
use std::panic::catch_unwind;
|
||||||
use std::stringify;
|
use std::stringify;
|
||||||
let expr_string = stringify!($e);
|
let expr_string = stringify!($e);
|
||||||
|
|
||||||
|
@ -22,37 +22,56 @@ macro_rules! assert_panics {
|
||||||
|
|
||||||
let result = catch_unwind(|| $e);
|
let result = catch_unwind(|| $e);
|
||||||
if result.is_ok() {
|
if result.is_ok() {
|
||||||
panic!("assert_panics!({}) failed: the expression did not panic.", expr_string);
|
panic!(
|
||||||
|
"assert_panics!({}) failed: the expression did not panic.",
|
||||||
|
expr_string
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}};
|
}};
|
||||||
}
|
}
|
||||||
|
|
||||||
pub const PROPTEST_MATRIX_DIM: RangeInclusive<usize> = 0..=6;
|
pub const PROPTEST_MATRIX_DIM: RangeInclusive<usize> = 0..=6;
|
||||||
pub const PROPTEST_MAX_NNZ: usize = 40;
|
pub const PROPTEST_MAX_NNZ: usize = 40;
|
||||||
pub const PROPTEST_I32_VALUE_STRATEGY: RangeInclusive<i32> = -5 ..= 5;
|
pub const PROPTEST_I32_VALUE_STRATEGY: RangeInclusive<i32> = -5..=5;
|
||||||
|
|
||||||
pub fn value_strategy<T>() -> RangeInclusive<T>
|
pub fn value_strategy<T>() -> RangeInclusive<T>
|
||||||
where
|
where
|
||||||
T: TryFrom<i32>,
|
T: TryFrom<i32>,
|
||||||
T::Error: Debug
|
T::Error: Debug,
|
||||||
{
|
{
|
||||||
let (start, end) = (PROPTEST_I32_VALUE_STRATEGY.start(), PROPTEST_I32_VALUE_STRATEGY.end());
|
let (start, end) = (
|
||||||
T::try_from(*start).unwrap() ..= T::try_from(*end).unwrap()
|
PROPTEST_I32_VALUE_STRATEGY.start(),
|
||||||
|
PROPTEST_I32_VALUE_STRATEGY.end(),
|
||||||
|
);
|
||||||
|
T::try_from(*start).unwrap()..=T::try_from(*end).unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn non_zero_i32_value_strategy() -> impl Strategy<Value=i32> {
|
pub fn non_zero_i32_value_strategy() -> impl Strategy<Value = i32> {
|
||||||
let (start, end) = (PROPTEST_I32_VALUE_STRATEGY.start(), PROPTEST_I32_VALUE_STRATEGY.end());
|
let (start, end) = (
|
||||||
|
PROPTEST_I32_VALUE_STRATEGY.start(),
|
||||||
|
PROPTEST_I32_VALUE_STRATEGY.end(),
|
||||||
|
);
|
||||||
assert!(start < &0);
|
assert!(start < &0);
|
||||||
assert!(end > &0);
|
assert!(end > &0);
|
||||||
// Note: we don't use RangeInclusive for the second range, because then we'd have different
|
// Note: we don't use RangeInclusive for the second range, because then we'd have different
|
||||||
// types, which would require boxing
|
// types, which would require boxing
|
||||||
(*start .. 0).prop_union(1 .. *end + 1)
|
(*start..0).prop_union(1..*end + 1)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn csr_strategy() -> impl Strategy<Value=CsrMatrix<i32>> {
|
pub fn csr_strategy() -> impl Strategy<Value = CsrMatrix<i32>> {
|
||||||
csr(PROPTEST_I32_VALUE_STRATEGY, PROPTEST_MATRIX_DIM, PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ)
|
csr(
|
||||||
|
PROPTEST_I32_VALUE_STRATEGY,
|
||||||
|
PROPTEST_MATRIX_DIM,
|
||||||
|
PROPTEST_MATRIX_DIM,
|
||||||
|
PROPTEST_MAX_NNZ,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn csc_strategy() -> impl Strategy<Value=CscMatrix<i32>> {
|
pub fn csc_strategy() -> impl Strategy<Value = CscMatrix<i32>> {
|
||||||
csc(PROPTEST_I32_VALUE_STRATEGY, PROPTEST_MATRIX_DIM, PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ)
|
csc(
|
||||||
|
PROPTEST_I32_VALUE_STRATEGY,
|
||||||
|
PROPTEST_MATRIX_DIM,
|
||||||
|
PROPTEST_MATRIX_DIM,
|
||||||
|
PROPTEST_MAX_NNZ,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,4 +5,4 @@ compile_error!("Tests must be run with features `proptest-support` and `compare`
|
||||||
mod unit_tests;
|
mod unit_tests;
|
||||||
|
|
||||||
#[macro_use]
|
#[macro_use]
|
||||||
pub mod common;
|
pub mod common;
|
||||||
|
|
|
@ -1,17 +1,16 @@
|
||||||
use nalgebra_sparse::coo::CooMatrix;
|
|
||||||
use nalgebra_sparse::convert::serial::{convert_coo_dense, convert_coo_csr,
|
|
||||||
convert_dense_coo, convert_csr_dense,
|
|
||||||
convert_csr_coo, convert_dense_csr,
|
|
||||||
convert_csc_coo, convert_coo_csc,
|
|
||||||
convert_csc_dense, convert_dense_csc,
|
|
||||||
convert_csr_csc, convert_csc_csr};
|
|
||||||
use nalgebra_sparse::proptest::{coo_with_duplicates, coo_no_duplicates, csr, csc};
|
|
||||||
use nalgebra::proptest::matrix;
|
|
||||||
use proptest::prelude::*;
|
|
||||||
use nalgebra::DMatrix;
|
|
||||||
use nalgebra_sparse::csr::CsrMatrix;
|
|
||||||
use nalgebra_sparse::csc::CscMatrix;
|
|
||||||
use crate::common::csc_strategy;
|
use crate::common::csc_strategy;
|
||||||
|
use nalgebra::proptest::matrix;
|
||||||
|
use nalgebra::DMatrix;
|
||||||
|
use nalgebra_sparse::convert::serial::{
|
||||||
|
convert_coo_csc, convert_coo_csr, convert_coo_dense, convert_csc_coo, convert_csc_csr,
|
||||||
|
convert_csc_dense, convert_csr_coo, convert_csr_csc, convert_csr_dense, convert_dense_coo,
|
||||||
|
convert_dense_csc, convert_dense_csr,
|
||||||
|
};
|
||||||
|
use nalgebra_sparse::coo::CooMatrix;
|
||||||
|
use nalgebra_sparse::csc::CscMatrix;
|
||||||
|
use nalgebra_sparse::csr::CsrMatrix;
|
||||||
|
use nalgebra_sparse::proptest::{coo_no_duplicates, coo_with_duplicates, csc, csr};
|
||||||
|
use proptest::prelude::*;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_convert_dense_coo() {
|
fn test_convert_dense_coo() {
|
||||||
|
@ -41,16 +40,17 @@ fn test_convert_dense_coo() {
|
||||||
// Here we implicitly test that the coo matrix is indeed constructed from column-major
|
// Here we implicitly test that the coo matrix is indeed constructed from column-major
|
||||||
// iteration of the dense matrix.
|
// iteration of the dense matrix.
|
||||||
let dense = DMatrix::from_row_slice(2, 3, entries);
|
let dense = DMatrix::from_row_slice(2, 3, entries);
|
||||||
let coo_no_dup = CooMatrix::try_from_triplets(2, 3,
|
let coo_no_dup =
|
||||||
vec![0, 1, 0],
|
CooMatrix::try_from_triplets(2, 3, vec![0, 1, 0], vec![0, 1, 2], vec![1, 5, 3])
|
||||||
vec![0, 1, 2],
|
.unwrap();
|
||||||
vec![1, 5, 3])
|
let coo_dup = CooMatrix::try_from_triplets(
|
||||||
.unwrap();
|
2,
|
||||||
let coo_dup = CooMatrix::try_from_triplets(2, 3,
|
3,
|
||||||
vec![0, 1, 0, 1],
|
vec![0, 1, 0, 1],
|
||||||
vec![0, 1, 2, 1],
|
vec![0, 1, 2, 1],
|
||||||
vec![1, -2, 3, 7])
|
vec![1, -2, 3, 7],
|
||||||
.unwrap();
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
assert_eq!(CooMatrix::from(&dense), coo_no_dup);
|
assert_eq!(CooMatrix::from(&dense), coo_no_dup);
|
||||||
assert_eq!(DMatrix::from(&coo_dup), dense);
|
assert_eq!(DMatrix::from(&coo_dup), dense);
|
||||||
|
@ -76,8 +76,9 @@ fn test_convert_coo_csr() {
|
||||||
4,
|
4,
|
||||||
vec![0, 1, 2, 5],
|
vec![0, 1, 2, 5],
|
||||||
vec![1, 3, 0, 2, 3],
|
vec![1, 3, 0, 2, 3],
|
||||||
vec![2, 4, 1, 1, 2]
|
vec![2, 4, 1, 1, 2],
|
||||||
).unwrap();
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
assert_eq!(convert_coo_csr(&coo), expected_csr);
|
assert_eq!(convert_coo_csr(&coo), expected_csr);
|
||||||
}
|
}
|
||||||
|
@ -101,8 +102,9 @@ fn test_convert_coo_csr() {
|
||||||
4,
|
4,
|
||||||
vec![0, 1, 2, 5],
|
vec![0, 1, 2, 5],
|
||||||
vec![1, 3, 0, 2, 3],
|
vec![1, 3, 0, 2, 3],
|
||||||
vec![5, 4, 1, 1, 4]
|
vec![5, 4, 1, 1, 4],
|
||||||
).unwrap();
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
assert_eq!(convert_coo_csr(&coo), expected_csr);
|
assert_eq!(convert_coo_csr(&coo), expected_csr);
|
||||||
}
|
}
|
||||||
|
@ -115,16 +117,18 @@ fn test_convert_csr_coo() {
|
||||||
4,
|
4,
|
||||||
vec![0, 1, 2, 5],
|
vec![0, 1, 2, 5],
|
||||||
vec![1, 3, 0, 2, 3],
|
vec![1, 3, 0, 2, 3],
|
||||||
vec![5, 4, 1, 1, 4]
|
vec![5, 4, 1, 1, 4],
|
||||||
).unwrap();
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let expected_coo = CooMatrix::try_from_triplets(
|
let expected_coo = CooMatrix::try_from_triplets(
|
||||||
3,
|
3,
|
||||||
4,
|
4,
|
||||||
vec![0, 1, 2, 2, 2],
|
vec![0, 1, 2, 2, 2],
|
||||||
vec![1, 3, 0, 2, 3],
|
vec![1, 3, 0, 2, 3],
|
||||||
vec![5, 4, 1, 1, 4]
|
vec![5, 4, 1, 1, 4],
|
||||||
).unwrap();
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
assert_eq!(convert_csr_coo(&csr), expected_coo);
|
assert_eq!(convert_csr_coo(&csr), expected_coo);
|
||||||
}
|
}
|
||||||
|
@ -148,8 +152,9 @@ fn test_convert_coo_csc() {
|
||||||
4,
|
4,
|
||||||
vec![0, 1, 2, 3, 5],
|
vec![0, 1, 2, 3, 5],
|
||||||
vec![2, 0, 2, 1, 2],
|
vec![2, 0, 2, 1, 2],
|
||||||
vec![1, 2, 1, 4, 2]
|
vec![1, 2, 1, 4, 2],
|
||||||
).unwrap();
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
assert_eq!(convert_coo_csc(&coo), expected_csc);
|
assert_eq!(convert_coo_csc(&coo), expected_csc);
|
||||||
}
|
}
|
||||||
|
@ -173,8 +178,9 @@ fn test_convert_coo_csc() {
|
||||||
4,
|
4,
|
||||||
vec![0, 1, 2, 3, 5],
|
vec![0, 1, 2, 3, 5],
|
||||||
vec![2, 0, 2, 1, 2],
|
vec![2, 0, 2, 1, 2],
|
||||||
vec![1, 5, 1, 4, 4]
|
vec![1, 5, 1, 4, 4],
|
||||||
).unwrap();
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
assert_eq!(convert_coo_csc(&coo), expected_csc);
|
assert_eq!(convert_coo_csc(&coo), expected_csc);
|
||||||
}
|
}
|
||||||
|
@ -187,16 +193,18 @@ fn test_convert_csc_coo() {
|
||||||
4,
|
4,
|
||||||
vec![0, 1, 2, 3, 5],
|
vec![0, 1, 2, 3, 5],
|
||||||
vec![2, 0, 2, 1, 2],
|
vec![2, 0, 2, 1, 2],
|
||||||
vec![1, 2, 1, 4, 2]
|
vec![1, 2, 1, 4, 2],
|
||||||
).unwrap();
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let expected_coo = CooMatrix::try_from_triplets(
|
let expected_coo = CooMatrix::try_from_triplets(
|
||||||
3,
|
3,
|
||||||
4,
|
4,
|
||||||
vec![2, 0, 2, 1, 2],
|
vec![2, 0, 2, 1, 2],
|
||||||
vec![0, 1, 2, 3, 3],
|
vec![0, 1, 2, 3, 3],
|
||||||
vec![1, 2, 1, 4, 2]
|
vec![1, 2, 1, 4, 2],
|
||||||
).unwrap();
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
assert_eq!(convert_csc_coo(&csc), expected_coo);
|
assert_eq!(convert_csc_coo(&csc), expected_coo);
|
||||||
}
|
}
|
||||||
|
@ -209,7 +217,8 @@ fn test_convert_csr_csc_bidirectional() {
|
||||||
vec![0, 3, 4, 6],
|
vec![0, 3, 4, 6],
|
||||||
vec![1, 2, 3, 0, 1, 3],
|
vec![1, 2, 3, 0, 1, 3],
|
||||||
vec![5, 3, 2, 2, 1, 4],
|
vec![5, 3, 2, 2, 1, 4],
|
||||||
).unwrap();
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let csc = CscMatrix::try_from_csc_data(
|
let csc = CscMatrix::try_from_csc_data(
|
||||||
3,
|
3,
|
||||||
|
@ -217,7 +226,8 @@ fn test_convert_csr_csc_bidirectional() {
|
||||||
vec![0, 1, 3, 4, 6],
|
vec![0, 1, 3, 4, 6],
|
||||||
vec![1, 0, 2, 0, 0, 2],
|
vec![1, 0, 2, 0, 0, 2],
|
||||||
vec![2, 5, 1, 3, 2, 4],
|
vec![2, 5, 1, 3, 2, 4],
|
||||||
).unwrap();
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
assert_eq!(convert_csr_csc(&csr), csc);
|
assert_eq!(convert_csr_csc(&csr), csc);
|
||||||
assert_eq!(convert_csc_csr(&csc), csr);
|
assert_eq!(convert_csc_csr(&csc), csr);
|
||||||
|
@ -231,7 +241,8 @@ fn test_convert_csr_dense_bidirectional() {
|
||||||
vec![0, 3, 4, 6],
|
vec![0, 3, 4, 6],
|
||||||
vec![1, 2, 3, 0, 1, 3],
|
vec![1, 2, 3, 0, 1, 3],
|
||||||
vec![5, 3, 2, 2, 1, 4],
|
vec![5, 3, 2, 2, 1, 4],
|
||||||
).unwrap();
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
#[rustfmt::skip]
|
#[rustfmt::skip]
|
||||||
let dense = DMatrix::from_row_slice(3, 4, &[
|
let dense = DMatrix::from_row_slice(3, 4, &[
|
||||||
|
@ -252,7 +263,8 @@ fn test_convert_csc_dense_bidirectional() {
|
||||||
vec![0, 1, 3, 4, 6],
|
vec![0, 1, 3, 4, 6],
|
||||||
vec![1, 0, 2, 0, 0, 2],
|
vec![1, 0, 2, 0, 0, 2],
|
||||||
vec![2, 5, 1, 3, 2, 4],
|
vec![2, 5, 1, 3, 2, 4],
|
||||||
).unwrap();
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
#[rustfmt::skip]
|
#[rustfmt::skip]
|
||||||
let dense = DMatrix::from_row_slice(3, 4, &[
|
let dense = DMatrix::from_row_slice(3, 4, &[
|
||||||
|
@ -265,29 +277,29 @@ fn test_convert_csc_dense_bidirectional() {
|
||||||
assert_eq!(convert_dense_csc(&dense), csc);
|
assert_eq!(convert_dense_csc(&dense), csc);
|
||||||
}
|
}
|
||||||
|
|
||||||
fn coo_strategy() -> impl Strategy<Value=CooMatrix<i32>> {
|
fn coo_strategy() -> impl Strategy<Value = CooMatrix<i32>> {
|
||||||
coo_with_duplicates(-5 ..= 5, 0..=6usize, 0..=6usize, 40, 2)
|
coo_with_duplicates(-5..=5, 0..=6usize, 0..=6usize, 40, 2)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn coo_no_duplicates_strategy() -> impl Strategy<Value=CooMatrix<i32>> {
|
fn coo_no_duplicates_strategy() -> impl Strategy<Value = CooMatrix<i32>> {
|
||||||
coo_no_duplicates(-5 ..= 5, 0..=6usize, 0..=6usize, 40)
|
coo_no_duplicates(-5..=5, 0..=6usize, 0..=6usize, 40)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn csr_strategy() -> impl Strategy<Value=CsrMatrix<i32>> {
|
fn csr_strategy() -> impl Strategy<Value = CsrMatrix<i32>> {
|
||||||
csr(-5 ..= 5, 0..=6usize, 0..=6usize, 40)
|
csr(-5..=5, 0..=6usize, 0..=6usize, 40)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Avoid generating explicit zero values so that it is possible to reason about sparsity patterns
|
/// Avoid generating explicit zero values so that it is possible to reason about sparsity patterns
|
||||||
fn non_zero_csr_strategy() -> impl Strategy<Value=CsrMatrix<i32>> {
|
fn non_zero_csr_strategy() -> impl Strategy<Value = CsrMatrix<i32>> {
|
||||||
csr(1 ..= 5, 0..=6usize, 0..=6usize, 40)
|
csr(1..=5, 0..=6usize, 0..=6usize, 40)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Avoid generating explicit zero values so that it is possible to reason about sparsity patterns
|
/// Avoid generating explicit zero values so that it is possible to reason about sparsity patterns
|
||||||
fn non_zero_csc_strategy() -> impl Strategy<Value=CscMatrix<i32>> {
|
fn non_zero_csc_strategy() -> impl Strategy<Value = CscMatrix<i32>> {
|
||||||
csc(1 ..= 5, 0..=6usize, 0..=6usize, 40)
|
csc(1..=5, 0..=6usize, 0..=6usize, 40)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn dense_strategy() -> impl Strategy<Value=DMatrix<i32>> {
|
fn dense_strategy() -> impl Strategy<Value = DMatrix<i32>> {
|
||||||
matrix(-5..=5, 0..=6, 0..=6)
|
matrix(-5..=5, 0..=6, 0..=6)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -437,4 +449,4 @@ proptest! {
|
||||||
fn csr_from_csc_roundtrip(csc in csc_strategy()) {
|
fn csr_from_csc_roundtrip(csc in csc_strategy()) {
|
||||||
prop_assert_eq!(&csc, &CscMatrix::from(&CsrMatrix::from(&csc)));
|
prop_assert_eq!(&csc, &CscMatrix::from(&CsrMatrix::from(&csc)));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
use nalgebra_sparse::{SparseFormatErrorKind};
|
|
||||||
use nalgebra_sparse::coo::CooMatrix;
|
|
||||||
use nalgebra::DMatrix;
|
|
||||||
use crate::assert_panics;
|
use crate::assert_panics;
|
||||||
|
use nalgebra::DMatrix;
|
||||||
|
use nalgebra_sparse::coo::CooMatrix;
|
||||||
|
use nalgebra_sparse::SparseFormatErrorKind;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn coo_construction_for_valid_data() {
|
fn coo_construction_for_valid_data() {
|
||||||
|
@ -10,8 +10,8 @@ fn coo_construction_for_valid_data() {
|
||||||
|
|
||||||
{
|
{
|
||||||
// Zero matrix
|
// Zero matrix
|
||||||
let coo = CooMatrix::<i32>::try_from_triplets(3, 2, Vec::new(), Vec::new(), Vec::new())
|
let coo =
|
||||||
.unwrap();
|
CooMatrix::<i32>::try_from_triplets(3, 2, Vec::new(), Vec::new(), Vec::new()).unwrap();
|
||||||
assert_eq!(coo.nrows(), 3);
|
assert_eq!(coo.nrows(), 3);
|
||||||
assert_eq!(coo.ncols(), 2);
|
assert_eq!(coo.ncols(), 2);
|
||||||
assert!(coo.triplet_iter().next().is_none());
|
assert!(coo.triplet_iter().next().is_none());
|
||||||
|
@ -27,8 +27,8 @@ fn coo_construction_for_valid_data() {
|
||||||
let i = vec![0, 1, 0, 0, 2];
|
let i = vec![0, 1, 0, 0, 2];
|
||||||
let j = vec![0, 2, 1, 3, 3];
|
let j = vec![0, 2, 1, 3, 3];
|
||||||
let v = vec![2, 3, 7, 3, 1];
|
let v = vec![2, 3, 7, 3, 1];
|
||||||
let coo = CooMatrix::<i32>::try_from_triplets(3, 5, i.clone(), j.clone(), v.clone())
|
let coo =
|
||||||
.unwrap();
|
CooMatrix::<i32>::try_from_triplets(3, 5, i.clone(), j.clone(), v.clone()).unwrap();
|
||||||
assert_eq!(coo.nrows(), 3);
|
assert_eq!(coo.nrows(), 3);
|
||||||
assert_eq!(coo.ncols(), 5);
|
assert_eq!(coo.ncols(), 5);
|
||||||
|
|
||||||
|
@ -59,8 +59,8 @@ fn coo_construction_for_valid_data() {
|
||||||
let i = vec![0, 1, 0, 0, 0, 0, 2, 1];
|
let i = vec![0, 1, 0, 0, 0, 0, 2, 1];
|
||||||
let j = vec![0, 2, 0, 1, 0, 3, 3, 2];
|
let j = vec![0, 2, 0, 1, 0, 3, 3, 2];
|
||||||
let v = vec![2, 3, 4, 7, 1, 3, 1, 5];
|
let v = vec![2, 3, 4, 7, 1, 3, 1, 5];
|
||||||
let coo = CooMatrix::<i32>::try_from_triplets(3, 5, i.clone(), j.clone(), v.clone())
|
let coo =
|
||||||
.unwrap();
|
CooMatrix::<i32>::try_from_triplets(3, 5, i.clone(), j.clone(), v.clone()).unwrap();
|
||||||
assert_eq!(coo.nrows(), 3);
|
assert_eq!(coo.nrows(), 3);
|
||||||
assert_eq!(coo.ncols(), 5);
|
assert_eq!(coo.ncols(), 5);
|
||||||
|
|
||||||
|
@ -92,25 +92,37 @@ fn coo_try_from_triplets_reports_out_of_bounds_indices() {
|
||||||
{
|
{
|
||||||
// 0x0 matrix
|
// 0x0 matrix
|
||||||
let result = CooMatrix::<i32>::try_from_triplets(0, 0, vec![0], vec![0], vec![2]);
|
let result = CooMatrix::<i32>::try_from_triplets(0, 0, vec![0], vec![0], vec![2]);
|
||||||
assert!(matches!(result.unwrap_err().kind(), SparseFormatErrorKind::IndexOutOfBounds));
|
assert!(matches!(
|
||||||
|
result.unwrap_err().kind(),
|
||||||
|
SparseFormatErrorKind::IndexOutOfBounds
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
// 1x1 matrix, row out of bounds
|
// 1x1 matrix, row out of bounds
|
||||||
let result = CooMatrix::<i32>::try_from_triplets(1, 1, vec![1], vec![0], vec![2]);
|
let result = CooMatrix::<i32>::try_from_triplets(1, 1, vec![1], vec![0], vec![2]);
|
||||||
assert!(matches!(result.unwrap_err().kind(), SparseFormatErrorKind::IndexOutOfBounds));
|
assert!(matches!(
|
||||||
|
result.unwrap_err().kind(),
|
||||||
|
SparseFormatErrorKind::IndexOutOfBounds
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
// 1x1 matrix, col out of bounds
|
// 1x1 matrix, col out of bounds
|
||||||
let result = CooMatrix::<i32>::try_from_triplets(1, 1, vec![0], vec![1], vec![2]);
|
let result = CooMatrix::<i32>::try_from_triplets(1, 1, vec![0], vec![1], vec![2]);
|
||||||
assert!(matches!(result.unwrap_err().kind(), SparseFormatErrorKind::IndexOutOfBounds));
|
assert!(matches!(
|
||||||
|
result.unwrap_err().kind(),
|
||||||
|
SparseFormatErrorKind::IndexOutOfBounds
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
// 1x1 matrix, row and col out of bounds
|
// 1x1 matrix, row and col out of bounds
|
||||||
let result = CooMatrix::<i32>::try_from_triplets(1, 1, vec![1], vec![1], vec![2]);
|
let result = CooMatrix::<i32>::try_from_triplets(1, 1, vec![1], vec![1], vec![2]);
|
||||||
assert!(matches!(result.unwrap_err().kind(), SparseFormatErrorKind::IndexOutOfBounds));
|
assert!(matches!(
|
||||||
|
result.unwrap_err().kind(),
|
||||||
|
SparseFormatErrorKind::IndexOutOfBounds
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
|
@ -119,7 +131,10 @@ fn coo_try_from_triplets_reports_out_of_bounds_indices() {
|
||||||
let j = vec![0, 2, 1, 3, 3];
|
let j = vec![0, 2, 1, 3, 3];
|
||||||
let v = vec![2, 3, 7, 3, 1];
|
let v = vec![2, 3, 7, 3, 1];
|
||||||
let result = CooMatrix::<i32>::try_from_triplets(3, 5, i, j, v);
|
let result = CooMatrix::<i32>::try_from_triplets(3, 5, i, j, v);
|
||||||
assert!(matches!(result.unwrap_err().kind(), SparseFormatErrorKind::IndexOutOfBounds));
|
assert!(matches!(
|
||||||
|
result.unwrap_err().kind(),
|
||||||
|
SparseFormatErrorKind::IndexOutOfBounds
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
|
@ -128,7 +143,10 @@ fn coo_try_from_triplets_reports_out_of_bounds_indices() {
|
||||||
let j = vec![0, 2, 1, 5, 3];
|
let j = vec![0, 2, 1, 5, 3];
|
||||||
let v = vec![2, 3, 7, 3, 1];
|
let v = vec![2, 3, 7, 3, 1];
|
||||||
let result = CooMatrix::<i32>::try_from_triplets(3, 5, i, j, v);
|
let result = CooMatrix::<i32>::try_from_triplets(3, 5, i, j, v);
|
||||||
assert!(matches!(result.unwrap_err().kind(), SparseFormatErrorKind::IndexOutOfBounds));
|
assert!(matches!(
|
||||||
|
result.unwrap_err().kind(),
|
||||||
|
SparseFormatErrorKind::IndexOutOfBounds
|
||||||
|
));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -137,16 +155,55 @@ fn coo_try_from_triplets_panics_on_mismatched_vectors() {
|
||||||
// Check that try_from_triplets panics when the triplet vectors have different lengths
|
// Check that try_from_triplets panics when the triplet vectors have different lengths
|
||||||
macro_rules! assert_errs {
|
macro_rules! assert_errs {
|
||||||
($result:expr) => {
|
($result:expr) => {
|
||||||
assert!(matches!($result.unwrap_err().kind(), SparseFormatErrorKind::InvalidStructure))
|
assert!(matches!(
|
||||||
}
|
$result.unwrap_err().kind(),
|
||||||
|
SparseFormatErrorKind::InvalidStructure
|
||||||
|
))
|
||||||
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
assert_errs!(CooMatrix::<i32>::try_from_triplets(3, 5, vec![1, 2], vec![0], vec![0]));
|
assert_errs!(CooMatrix::<i32>::try_from_triplets(
|
||||||
assert_errs!(CooMatrix::<i32>::try_from_triplets(3, 5, vec![1], vec![0, 0], vec![0]));
|
3,
|
||||||
assert_errs!(CooMatrix::<i32>::try_from_triplets(3, 5, vec![1], vec![0], vec![0, 1]));
|
5,
|
||||||
assert_errs!(CooMatrix::<i32>::try_from_triplets(3, 5, vec![1, 2], vec![0, 1], vec![0]));
|
vec![1, 2],
|
||||||
assert_errs!(CooMatrix::<i32>::try_from_triplets(3, 5, vec![1], vec![0, 1], vec![0, 1]));
|
vec![0],
|
||||||
assert_errs!(CooMatrix::<i32>::try_from_triplets(3, 5, vec![1, 1], vec![0], vec![0, 1]));
|
vec![0]
|
||||||
|
));
|
||||||
|
assert_errs!(CooMatrix::<i32>::try_from_triplets(
|
||||||
|
3,
|
||||||
|
5,
|
||||||
|
vec![1],
|
||||||
|
vec![0, 0],
|
||||||
|
vec![0]
|
||||||
|
));
|
||||||
|
assert_errs!(CooMatrix::<i32>::try_from_triplets(
|
||||||
|
3,
|
||||||
|
5,
|
||||||
|
vec![1],
|
||||||
|
vec![0],
|
||||||
|
vec![0, 1]
|
||||||
|
));
|
||||||
|
assert_errs!(CooMatrix::<i32>::try_from_triplets(
|
||||||
|
3,
|
||||||
|
5,
|
||||||
|
vec![1, 2],
|
||||||
|
vec![0, 1],
|
||||||
|
vec![0]
|
||||||
|
));
|
||||||
|
assert_errs!(CooMatrix::<i32>::try_from_triplets(
|
||||||
|
3,
|
||||||
|
5,
|
||||||
|
vec![1],
|
||||||
|
vec![0, 1],
|
||||||
|
vec![0, 1]
|
||||||
|
));
|
||||||
|
assert_errs!(CooMatrix::<i32>::try_from_triplets(
|
||||||
|
3,
|
||||||
|
5,
|
||||||
|
vec![1, 1],
|
||||||
|
vec![0],
|
||||||
|
vec![0, 1]
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
@ -157,10 +214,16 @@ fn coo_push_valid_entries() {
|
||||||
assert_eq!(coo.triplet_iter().collect::<Vec<_>>(), vec![(0, 0, &1)]);
|
assert_eq!(coo.triplet_iter().collect::<Vec<_>>(), vec![(0, 0, &1)]);
|
||||||
|
|
||||||
coo.push(0, 0, 2);
|
coo.push(0, 0, 2);
|
||||||
assert_eq!(coo.triplet_iter().collect::<Vec<_>>(), vec![(0, 0, &1), (0, 0, &2)]);
|
assert_eq!(
|
||||||
|
coo.triplet_iter().collect::<Vec<_>>(),
|
||||||
|
vec![(0, 0, &1), (0, 0, &2)]
|
||||||
|
);
|
||||||
|
|
||||||
coo.push(2, 2, 3);
|
coo.push(2, 2, 3);
|
||||||
assert_eq!(coo.triplet_iter().collect::<Vec<_>>(), vec![(0, 0, &1), (0, 0, &2), (2, 2, &3)]);
|
assert_eq!(
|
||||||
|
coo.triplet_iter().collect::<Vec<_>>(),
|
||||||
|
vec![(0, 0, &1), (0, 0, &2), (2, 2, &3)]
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
@ -188,4 +251,4 @@ fn coo_push_out_of_bounds_entries() {
|
||||||
assert_panics!(coo.clone().push(2, 2, 1));
|
assert_panics!(coo.clone().push(2, 2, 1));
|
||||||
assert_panics!(coo.clone().push(3, 2, 1));
|
assert_panics!(coo.clone().push(3, 2, 1));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
|
use nalgebra::DMatrix;
|
||||||
use nalgebra_sparse::csc::CscMatrix;
|
use nalgebra_sparse::csc::CscMatrix;
|
||||||
use nalgebra_sparse::SparseFormatErrorKind;
|
use nalgebra_sparse::SparseFormatErrorKind;
|
||||||
use nalgebra::DMatrix;
|
|
||||||
|
|
||||||
use proptest::prelude::*;
|
use proptest::prelude::*;
|
||||||
use proptest::sample::subsequence;
|
use proptest::sample::subsequence;
|
||||||
|
@ -42,7 +42,10 @@ fn csc_matrix_valid_data() {
|
||||||
assert_eq!(matrix.col_mut(0).row_indices(), &[]);
|
assert_eq!(matrix.col_mut(0).row_indices(), &[]);
|
||||||
assert_eq!(matrix.col_mut(0).values(), &[]);
|
assert_eq!(matrix.col_mut(0).values(), &[]);
|
||||||
assert_eq!(matrix.col_mut(0).values_mut(), &[]);
|
assert_eq!(matrix.col_mut(0).values_mut(), &[]);
|
||||||
assert_eq!(matrix.col_mut(0).rows_and_values_mut(), ([].as_ref(), [].as_mut()));
|
assert_eq!(
|
||||||
|
matrix.col_mut(0).rows_and_values_mut(),
|
||||||
|
([].as_ref(), [].as_mut())
|
||||||
|
);
|
||||||
|
|
||||||
assert_eq!(matrix.col(1).nrows(), 2);
|
assert_eq!(matrix.col(1).nrows(), 2);
|
||||||
assert_eq!(matrix.col(1).nnz(), 0);
|
assert_eq!(matrix.col(1).nnz(), 0);
|
||||||
|
@ -53,7 +56,10 @@ fn csc_matrix_valid_data() {
|
||||||
assert_eq!(matrix.col_mut(1).row_indices(), &[]);
|
assert_eq!(matrix.col_mut(1).row_indices(), &[]);
|
||||||
assert_eq!(matrix.col_mut(1).values(), &[]);
|
assert_eq!(matrix.col_mut(1).values(), &[]);
|
||||||
assert_eq!(matrix.col_mut(1).values_mut(), &[]);
|
assert_eq!(matrix.col_mut(1).values_mut(), &[]);
|
||||||
assert_eq!(matrix.col_mut(1).rows_and_values_mut(), ([].as_ref(), [].as_mut()));
|
assert_eq!(
|
||||||
|
matrix.col_mut(1).rows_and_values_mut(),
|
||||||
|
([].as_ref(), [].as_mut())
|
||||||
|
);
|
||||||
|
|
||||||
assert_eq!(matrix.col(2).nrows(), 2);
|
assert_eq!(matrix.col(2).nrows(), 2);
|
||||||
assert_eq!(matrix.col(2).nnz(), 0);
|
assert_eq!(matrix.col(2).nnz(), 0);
|
||||||
|
@ -64,7 +70,10 @@ fn csc_matrix_valid_data() {
|
||||||
assert_eq!(matrix.col_mut(2).row_indices(), &[]);
|
assert_eq!(matrix.col_mut(2).row_indices(), &[]);
|
||||||
assert_eq!(matrix.col_mut(2).values(), &[]);
|
assert_eq!(matrix.col_mut(2).values(), &[]);
|
||||||
assert_eq!(matrix.col_mut(2).values_mut(), &[]);
|
assert_eq!(matrix.col_mut(2).values_mut(), &[]);
|
||||||
assert_eq!(matrix.col_mut(2).rows_and_values_mut(), ([].as_ref(), [].as_mut()));
|
assert_eq!(
|
||||||
|
matrix.col_mut(2).rows_and_values_mut(),
|
||||||
|
([].as_ref(), [].as_mut())
|
||||||
|
);
|
||||||
|
|
||||||
assert!(matrix.get_col(3).is_none());
|
assert!(matrix.get_col(3).is_none());
|
||||||
assert!(matrix.get_col_mut(3).is_none());
|
assert!(matrix.get_col_mut(3).is_none());
|
||||||
|
@ -81,11 +90,9 @@ fn csc_matrix_valid_data() {
|
||||||
let offsets = vec![0, 2, 2, 5];
|
let offsets = vec![0, 2, 2, 5];
|
||||||
let indices = vec![0, 5, 1, 2, 3];
|
let indices = vec![0, 5, 1, 2, 3];
|
||||||
let values = vec![0, 1, 2, 3, 4];
|
let values = vec![0, 1, 2, 3, 4];
|
||||||
let mut matrix = CscMatrix::try_from_csc_data(6,
|
let mut matrix =
|
||||||
3,
|
CscMatrix::try_from_csc_data(6, 3, offsets.clone(), indices.clone(), values.clone())
|
||||||
offsets.clone(),
|
.unwrap();
|
||||||
indices.clone(),
|
|
||||||
values.clone()).unwrap();
|
|
||||||
|
|
||||||
assert_eq!(matrix.nrows(), 6);
|
assert_eq!(matrix.nrows(), 6);
|
||||||
assert_eq!(matrix.ncols(), 3);
|
assert_eq!(matrix.ncols(), 3);
|
||||||
|
@ -95,10 +102,20 @@ fn csc_matrix_valid_data() {
|
||||||
assert_eq!(matrix.values(), &[0, 1, 2, 3, 4]);
|
assert_eq!(matrix.values(), &[0, 1, 2, 3, 4]);
|
||||||
|
|
||||||
let expected_triplets = vec![(0, 0, 0), (5, 0, 1), (1, 2, 2), (2, 2, 3), (3, 2, 4)];
|
let expected_triplets = vec![(0, 0, 0), (5, 0, 1), (1, 2, 2), (2, 2, 3), (3, 2, 4)];
|
||||||
assert_eq!(matrix.triplet_iter().map(|(i, j, v)| (i, j, *v)).collect::<Vec<_>>(),
|
assert_eq!(
|
||||||
expected_triplets);
|
matrix
|
||||||
assert_eq!(matrix.triplet_iter_mut().map(|(i, j, v)| (i, j, *v)).collect::<Vec<_>>(),
|
.triplet_iter()
|
||||||
expected_triplets);
|
.map(|(i, j, v)| (i, j, *v))
|
||||||
|
.collect::<Vec<_>>(),
|
||||||
|
expected_triplets
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
matrix
|
||||||
|
.triplet_iter_mut()
|
||||||
|
.map(|(i, j, v)| (i, j, *v))
|
||||||
|
.collect::<Vec<_>>(),
|
||||||
|
expected_triplets
|
||||||
|
);
|
||||||
|
|
||||||
assert_eq!(matrix.col(0).nrows(), 6);
|
assert_eq!(matrix.col(0).nrows(), 6);
|
||||||
assert_eq!(matrix.col(0).nnz(), 2);
|
assert_eq!(matrix.col(0).nnz(), 2);
|
||||||
|
@ -109,7 +126,10 @@ fn csc_matrix_valid_data() {
|
||||||
assert_eq!(matrix.col_mut(0).row_indices(), &[0, 5]);
|
assert_eq!(matrix.col_mut(0).row_indices(), &[0, 5]);
|
||||||
assert_eq!(matrix.col_mut(0).values(), &[0, 1]);
|
assert_eq!(matrix.col_mut(0).values(), &[0, 1]);
|
||||||
assert_eq!(matrix.col_mut(0).values_mut(), &[0, 1]);
|
assert_eq!(matrix.col_mut(0).values_mut(), &[0, 1]);
|
||||||
assert_eq!(matrix.col_mut(0).rows_and_values_mut(), ([0, 5].as_ref(), [0, 1].as_mut()));
|
assert_eq!(
|
||||||
|
matrix.col_mut(0).rows_and_values_mut(),
|
||||||
|
([0, 5].as_ref(), [0, 1].as_mut())
|
||||||
|
);
|
||||||
|
|
||||||
assert_eq!(matrix.col(1).nrows(), 6);
|
assert_eq!(matrix.col(1).nrows(), 6);
|
||||||
assert_eq!(matrix.col(1).nnz(), 0);
|
assert_eq!(matrix.col(1).nnz(), 0);
|
||||||
|
@ -120,7 +140,10 @@ fn csc_matrix_valid_data() {
|
||||||
assert_eq!(matrix.col_mut(1).row_indices(), &[]);
|
assert_eq!(matrix.col_mut(1).row_indices(), &[]);
|
||||||
assert_eq!(matrix.col_mut(1).values(), &[]);
|
assert_eq!(matrix.col_mut(1).values(), &[]);
|
||||||
assert_eq!(matrix.col_mut(1).values_mut(), &[]);
|
assert_eq!(matrix.col_mut(1).values_mut(), &[]);
|
||||||
assert_eq!(matrix.col_mut(1).rows_and_values_mut(), ([].as_ref(), [].as_mut()));
|
assert_eq!(
|
||||||
|
matrix.col_mut(1).rows_and_values_mut(),
|
||||||
|
([].as_ref(), [].as_mut())
|
||||||
|
);
|
||||||
|
|
||||||
assert_eq!(matrix.col(2).nrows(), 6);
|
assert_eq!(matrix.col(2).nrows(), 6);
|
||||||
assert_eq!(matrix.col(2).nnz(), 3);
|
assert_eq!(matrix.col(2).nnz(), 3);
|
||||||
|
@ -131,7 +154,10 @@ fn csc_matrix_valid_data() {
|
||||||
assert_eq!(matrix.col_mut(2).row_indices(), &[1, 2, 3]);
|
assert_eq!(matrix.col_mut(2).row_indices(), &[1, 2, 3]);
|
||||||
assert_eq!(matrix.col_mut(2).values(), &[2, 3, 4]);
|
assert_eq!(matrix.col_mut(2).values(), &[2, 3, 4]);
|
||||||
assert_eq!(matrix.col_mut(2).values_mut(), &[2, 3, 4]);
|
assert_eq!(matrix.col_mut(2).values_mut(), &[2, 3, 4]);
|
||||||
assert_eq!(matrix.col_mut(2).rows_and_values_mut(), ([1, 2, 3].as_ref(), [2, 3, 4].as_mut()));
|
assert_eq!(
|
||||||
|
matrix.col_mut(2).rows_and_values_mut(),
|
||||||
|
([1, 2, 3].as_ref(), [2, 3, 4].as_mut())
|
||||||
|
);
|
||||||
|
|
||||||
assert!(matrix.get_col(3).is_none());
|
assert!(matrix.get_col(3).is_none());
|
||||||
assert!(matrix.get_col_mut(3).is_none());
|
assert!(matrix.get_col_mut(3).is_none());
|
||||||
|
@ -146,11 +172,13 @@ fn csc_matrix_valid_data() {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn csc_matrix_try_from_invalid_csc_data() {
|
fn csc_matrix_try_from_invalid_csc_data() {
|
||||||
|
|
||||||
{
|
{
|
||||||
// Empty offset array (invalid length)
|
// Empty offset array (invalid length)
|
||||||
let matrix = CscMatrix::try_from_csc_data(0, 0, Vec::new(), Vec::new(), Vec::<u32>::new());
|
let matrix = CscMatrix::try_from_csc_data(0, 0, Vec::new(), Vec::new(), Vec::<u32>::new());
|
||||||
assert_eq!(matrix.unwrap_err().kind(), &SparseFormatErrorKind::InvalidStructure);
|
assert_eq!(
|
||||||
|
matrix.unwrap_err().kind(),
|
||||||
|
&SparseFormatErrorKind::InvalidStructure
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
|
@ -160,7 +188,10 @@ fn csc_matrix_try_from_invalid_csc_data() {
|
||||||
let values = vec![0, 1, 2, 3, 4];
|
let values = vec![0, 1, 2, 3, 4];
|
||||||
|
|
||||||
let matrix = CscMatrix::try_from_csc_data(6, 3, offsets, indices, values);
|
let matrix = CscMatrix::try_from_csc_data(6, 3, offsets, indices, values);
|
||||||
assert_eq!(matrix.unwrap_err().kind(), &SparseFormatErrorKind::InvalidStructure);
|
assert_eq!(
|
||||||
|
matrix.unwrap_err().kind(),
|
||||||
|
&SparseFormatErrorKind::InvalidStructure
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
|
@ -169,7 +200,10 @@ fn csc_matrix_try_from_invalid_csc_data() {
|
||||||
let indices = vec![0, 5, 1, 2, 3];
|
let indices = vec![0, 5, 1, 2, 3];
|
||||||
let values = vec![0, 1, 2, 3, 4];
|
let values = vec![0, 1, 2, 3, 4];
|
||||||
let matrix = CscMatrix::try_from_csc_data(6, 3, offsets, indices, values);
|
let matrix = CscMatrix::try_from_csc_data(6, 3, offsets, indices, values);
|
||||||
assert_eq!(matrix.unwrap_err().kind(), &SparseFormatErrorKind::InvalidStructure);
|
assert_eq!(
|
||||||
|
matrix.unwrap_err().kind(),
|
||||||
|
&SparseFormatErrorKind::InvalidStructure
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
|
@ -178,7 +212,10 @@ fn csc_matrix_try_from_invalid_csc_data() {
|
||||||
let indices = vec![0, 5, 1, 2, 3];
|
let indices = vec![0, 5, 1, 2, 3];
|
||||||
let values = vec![0, 1, 2, 3, 4];
|
let values = vec![0, 1, 2, 3, 4];
|
||||||
let matrix = CscMatrix::try_from_csc_data(6, 3, offsets, indices, values);
|
let matrix = CscMatrix::try_from_csc_data(6, 3, offsets, indices, values);
|
||||||
assert_eq!(matrix.unwrap_err().kind(), &SparseFormatErrorKind::InvalidStructure);
|
assert_eq!(
|
||||||
|
matrix.unwrap_err().kind(),
|
||||||
|
&SparseFormatErrorKind::InvalidStructure
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
|
@ -187,7 +224,10 @@ fn csc_matrix_try_from_invalid_csc_data() {
|
||||||
let indices = vec![0, 5, 1, 2, 3];
|
let indices = vec![0, 5, 1, 2, 3];
|
||||||
let values = vec![0, 1, 2, 3, 4];
|
let values = vec![0, 1, 2, 3, 4];
|
||||||
let matrix = CscMatrix::try_from_csc_data(6, 3, offsets, indices, values);
|
let matrix = CscMatrix::try_from_csc_data(6, 3, offsets, indices, values);
|
||||||
assert_eq!(matrix.unwrap_err().kind(), &SparseFormatErrorKind::InvalidStructure);
|
assert_eq!(
|
||||||
|
matrix.unwrap_err().kind(),
|
||||||
|
&SparseFormatErrorKind::InvalidStructure
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
|
@ -196,7 +236,10 @@ fn csc_matrix_try_from_invalid_csc_data() {
|
||||||
let indices = vec![0, 1, 2, 3, 4];
|
let indices = vec![0, 1, 2, 3, 4];
|
||||||
let values = vec![0, 1, 2, 3, 4];
|
let values = vec![0, 1, 2, 3, 4];
|
||||||
let matrix = CscMatrix::try_from_csc_data(6, 3, offsets, indices, values);
|
let matrix = CscMatrix::try_from_csc_data(6, 3, offsets, indices, values);
|
||||||
assert_eq!(matrix.unwrap_err().kind(), &SparseFormatErrorKind::InvalidStructure);
|
assert_eq!(
|
||||||
|
matrix.unwrap_err().kind(),
|
||||||
|
&SparseFormatErrorKind::InvalidStructure
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
|
@ -205,7 +248,10 @@ fn csc_matrix_try_from_invalid_csc_data() {
|
||||||
let indices = vec![0, 2, 3, 1, 4];
|
let indices = vec![0, 2, 3, 1, 4];
|
||||||
let values = vec![0, 1, 2, 3, 4];
|
let values = vec![0, 1, 2, 3, 4];
|
||||||
let matrix = CscMatrix::try_from_csc_data(6, 3, offsets, indices, values);
|
let matrix = CscMatrix::try_from_csc_data(6, 3, offsets, indices, values);
|
||||||
assert_eq!(matrix.unwrap_err().kind(), &SparseFormatErrorKind::InvalidStructure);
|
assert_eq!(
|
||||||
|
matrix.unwrap_err().kind(),
|
||||||
|
&SparseFormatErrorKind::InvalidStructure
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
|
@ -214,7 +260,10 @@ fn csc_matrix_try_from_invalid_csc_data() {
|
||||||
let indices = vec![0, 6, 1, 2, 3];
|
let indices = vec![0, 6, 1, 2, 3];
|
||||||
let values = vec![0, 1, 2, 3, 4];
|
let values = vec![0, 1, 2, 3, 4];
|
||||||
let matrix = CscMatrix::try_from_csc_data(6, 3, offsets, indices, values);
|
let matrix = CscMatrix::try_from_csc_data(6, 3, offsets, indices, values);
|
||||||
assert_eq!(matrix.unwrap_err().kind(), &SparseFormatErrorKind::IndexOutOfBounds);
|
assert_eq!(
|
||||||
|
matrix.unwrap_err().kind(),
|
||||||
|
&SparseFormatErrorKind::IndexOutOfBounds
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
|
@ -223,9 +272,11 @@ fn csc_matrix_try_from_invalid_csc_data() {
|
||||||
let indices = vec![0, 5, 2, 2, 3];
|
let indices = vec![0, 5, 2, 2, 3];
|
||||||
let values = vec![0, 1, 2, 3, 4];
|
let values = vec![0, 1, 2, 3, 4];
|
||||||
let matrix = CscMatrix::try_from_csc_data(6, 3, offsets, indices, values);
|
let matrix = CscMatrix::try_from_csc_data(6, 3, offsets, indices, values);
|
||||||
assert_eq!(matrix.unwrap_err().kind(), &SparseFormatErrorKind::DuplicateEntry);
|
assert_eq!(
|
||||||
|
matrix.unwrap_err().kind(),
|
||||||
|
&SparseFormatErrorKind::DuplicateEntry
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
@ -239,11 +290,7 @@ fn csc_disassemble_avoids_clone_when_owned() {
|
||||||
let offsets_ptr = offsets.as_ptr();
|
let offsets_ptr = offsets.as_ptr();
|
||||||
let indices_ptr = indices.as_ptr();
|
let indices_ptr = indices.as_ptr();
|
||||||
let values_ptr = values.as_ptr();
|
let values_ptr = values.as_ptr();
|
||||||
let matrix = CscMatrix::try_from_csc_data(6,
|
let matrix = CscMatrix::try_from_csc_data(6, 3, offsets, indices, values).unwrap();
|
||||||
3,
|
|
||||||
offsets,
|
|
||||||
indices,
|
|
||||||
values).unwrap();
|
|
||||||
|
|
||||||
let (offsets, indices, values) = matrix.disassemble();
|
let (offsets, indices, values) = matrix.disassemble();
|
||||||
assert_eq!(offsets.as_ptr(), offsets_ptr);
|
assert_eq!(offsets.as_ptr(), offsets_ptr);
|
||||||
|
@ -328,4 +375,4 @@ proptest! {
|
||||||
|
|
||||||
prop_assert_eq!(d_entries, csc_diagonal_entries);
|
prop_assert_eq!(d_entries, csc_diagonal_entries);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
|
use nalgebra::DMatrix;
|
||||||
use nalgebra_sparse::csr::CsrMatrix;
|
use nalgebra_sparse::csr::CsrMatrix;
|
||||||
use nalgebra_sparse::SparseFormatErrorKind;
|
use nalgebra_sparse::SparseFormatErrorKind;
|
||||||
use nalgebra::DMatrix;
|
|
||||||
|
|
||||||
use proptest::prelude::*;
|
use proptest::prelude::*;
|
||||||
use proptest::sample::subsequence;
|
use proptest::sample::subsequence;
|
||||||
|
@ -9,7 +9,6 @@ use crate::common::csr_strategy;
|
||||||
|
|
||||||
use std::collections::HashSet;
|
use std::collections::HashSet;
|
||||||
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn csr_matrix_valid_data() {
|
fn csr_matrix_valid_data() {
|
||||||
// Construct matrix from valid data and check that selected methods return results
|
// Construct matrix from valid data and check that selected methods return results
|
||||||
|
@ -43,7 +42,10 @@ fn csr_matrix_valid_data() {
|
||||||
assert_eq!(matrix.row_mut(0).col_indices(), &[]);
|
assert_eq!(matrix.row_mut(0).col_indices(), &[]);
|
||||||
assert_eq!(matrix.row_mut(0).values(), &[]);
|
assert_eq!(matrix.row_mut(0).values(), &[]);
|
||||||
assert_eq!(matrix.row_mut(0).values_mut(), &[]);
|
assert_eq!(matrix.row_mut(0).values_mut(), &[]);
|
||||||
assert_eq!(matrix.row_mut(0).cols_and_values_mut(), ([].as_ref(), [].as_mut()));
|
assert_eq!(
|
||||||
|
matrix.row_mut(0).cols_and_values_mut(),
|
||||||
|
([].as_ref(), [].as_mut())
|
||||||
|
);
|
||||||
|
|
||||||
assert_eq!(matrix.row(1).ncols(), 2);
|
assert_eq!(matrix.row(1).ncols(), 2);
|
||||||
assert_eq!(matrix.row(1).nnz(), 0);
|
assert_eq!(matrix.row(1).nnz(), 0);
|
||||||
|
@ -54,7 +56,10 @@ fn csr_matrix_valid_data() {
|
||||||
assert_eq!(matrix.row_mut(1).col_indices(), &[]);
|
assert_eq!(matrix.row_mut(1).col_indices(), &[]);
|
||||||
assert_eq!(matrix.row_mut(1).values(), &[]);
|
assert_eq!(matrix.row_mut(1).values(), &[]);
|
||||||
assert_eq!(matrix.row_mut(1).values_mut(), &[]);
|
assert_eq!(matrix.row_mut(1).values_mut(), &[]);
|
||||||
assert_eq!(matrix.row_mut(1).cols_and_values_mut(), ([].as_ref(), [].as_mut()));
|
assert_eq!(
|
||||||
|
matrix.row_mut(1).cols_and_values_mut(),
|
||||||
|
([].as_ref(), [].as_mut())
|
||||||
|
);
|
||||||
|
|
||||||
assert_eq!(matrix.row(2).ncols(), 2);
|
assert_eq!(matrix.row(2).ncols(), 2);
|
||||||
assert_eq!(matrix.row(2).nnz(), 0);
|
assert_eq!(matrix.row(2).nnz(), 0);
|
||||||
|
@ -65,7 +70,10 @@ fn csr_matrix_valid_data() {
|
||||||
assert_eq!(matrix.row_mut(2).col_indices(), &[]);
|
assert_eq!(matrix.row_mut(2).col_indices(), &[]);
|
||||||
assert_eq!(matrix.row_mut(2).values(), &[]);
|
assert_eq!(matrix.row_mut(2).values(), &[]);
|
||||||
assert_eq!(matrix.row_mut(2).values_mut(), &[]);
|
assert_eq!(matrix.row_mut(2).values_mut(), &[]);
|
||||||
assert_eq!(matrix.row_mut(2).cols_and_values_mut(), ([].as_ref(), [].as_mut()));
|
assert_eq!(
|
||||||
|
matrix.row_mut(2).cols_and_values_mut(),
|
||||||
|
([].as_ref(), [].as_mut())
|
||||||
|
);
|
||||||
|
|
||||||
assert!(matrix.get_row(3).is_none());
|
assert!(matrix.get_row(3).is_none());
|
||||||
assert!(matrix.get_row_mut(3).is_none());
|
assert!(matrix.get_row_mut(3).is_none());
|
||||||
|
@ -82,11 +90,9 @@ fn csr_matrix_valid_data() {
|
||||||
let offsets = vec![0, 2, 2, 5];
|
let offsets = vec![0, 2, 2, 5];
|
||||||
let indices = vec![0, 5, 1, 2, 3];
|
let indices = vec![0, 5, 1, 2, 3];
|
||||||
let values = vec![0, 1, 2, 3, 4];
|
let values = vec![0, 1, 2, 3, 4];
|
||||||
let mut matrix = CsrMatrix::try_from_csr_data(3,
|
let mut matrix =
|
||||||
6,
|
CsrMatrix::try_from_csr_data(3, 6, offsets.clone(), indices.clone(), values.clone())
|
||||||
offsets.clone(),
|
.unwrap();
|
||||||
indices.clone(),
|
|
||||||
values.clone()).unwrap();
|
|
||||||
|
|
||||||
assert_eq!(matrix.nrows(), 3);
|
assert_eq!(matrix.nrows(), 3);
|
||||||
assert_eq!(matrix.ncols(), 6);
|
assert_eq!(matrix.ncols(), 6);
|
||||||
|
@ -96,10 +102,20 @@ fn csr_matrix_valid_data() {
|
||||||
assert_eq!(matrix.values(), &[0, 1, 2, 3, 4]);
|
assert_eq!(matrix.values(), &[0, 1, 2, 3, 4]);
|
||||||
|
|
||||||
let expected_triplets = vec![(0, 0, 0), (0, 5, 1), (2, 1, 2), (2, 2, 3), (2, 3, 4)];
|
let expected_triplets = vec![(0, 0, 0), (0, 5, 1), (2, 1, 2), (2, 2, 3), (2, 3, 4)];
|
||||||
assert_eq!(matrix.triplet_iter().map(|(i, j, v)| (i, j, *v)).collect::<Vec<_>>(),
|
assert_eq!(
|
||||||
expected_triplets);
|
matrix
|
||||||
assert_eq!(matrix.triplet_iter_mut().map(|(i, j, v)| (i, j, *v)).collect::<Vec<_>>(),
|
.triplet_iter()
|
||||||
expected_triplets);
|
.map(|(i, j, v)| (i, j, *v))
|
||||||
|
.collect::<Vec<_>>(),
|
||||||
|
expected_triplets
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
matrix
|
||||||
|
.triplet_iter_mut()
|
||||||
|
.map(|(i, j, v)| (i, j, *v))
|
||||||
|
.collect::<Vec<_>>(),
|
||||||
|
expected_triplets
|
||||||
|
);
|
||||||
|
|
||||||
assert_eq!(matrix.row(0).ncols(), 6);
|
assert_eq!(matrix.row(0).ncols(), 6);
|
||||||
assert_eq!(matrix.row(0).nnz(), 2);
|
assert_eq!(matrix.row(0).nnz(), 2);
|
||||||
|
@ -110,7 +126,10 @@ fn csr_matrix_valid_data() {
|
||||||
assert_eq!(matrix.row_mut(0).col_indices(), &[0, 5]);
|
assert_eq!(matrix.row_mut(0).col_indices(), &[0, 5]);
|
||||||
assert_eq!(matrix.row_mut(0).values(), &[0, 1]);
|
assert_eq!(matrix.row_mut(0).values(), &[0, 1]);
|
||||||
assert_eq!(matrix.row_mut(0).values_mut(), &[0, 1]);
|
assert_eq!(matrix.row_mut(0).values_mut(), &[0, 1]);
|
||||||
assert_eq!(matrix.row_mut(0).cols_and_values_mut(), ([0, 5].as_ref(), [0, 1].as_mut()));
|
assert_eq!(
|
||||||
|
matrix.row_mut(0).cols_and_values_mut(),
|
||||||
|
([0, 5].as_ref(), [0, 1].as_mut())
|
||||||
|
);
|
||||||
|
|
||||||
assert_eq!(matrix.row(1).ncols(), 6);
|
assert_eq!(matrix.row(1).ncols(), 6);
|
||||||
assert_eq!(matrix.row(1).nnz(), 0);
|
assert_eq!(matrix.row(1).nnz(), 0);
|
||||||
|
@ -121,7 +140,10 @@ fn csr_matrix_valid_data() {
|
||||||
assert_eq!(matrix.row_mut(1).col_indices(), &[]);
|
assert_eq!(matrix.row_mut(1).col_indices(), &[]);
|
||||||
assert_eq!(matrix.row_mut(1).values(), &[]);
|
assert_eq!(matrix.row_mut(1).values(), &[]);
|
||||||
assert_eq!(matrix.row_mut(1).values_mut(), &[]);
|
assert_eq!(matrix.row_mut(1).values_mut(), &[]);
|
||||||
assert_eq!(matrix.row_mut(1).cols_and_values_mut(), ([].as_ref(), [].as_mut()));
|
assert_eq!(
|
||||||
|
matrix.row_mut(1).cols_and_values_mut(),
|
||||||
|
([].as_ref(), [].as_mut())
|
||||||
|
);
|
||||||
|
|
||||||
assert_eq!(matrix.row(2).ncols(), 6);
|
assert_eq!(matrix.row(2).ncols(), 6);
|
||||||
assert_eq!(matrix.row(2).nnz(), 3);
|
assert_eq!(matrix.row(2).nnz(), 3);
|
||||||
|
@ -132,7 +154,10 @@ fn csr_matrix_valid_data() {
|
||||||
assert_eq!(matrix.row_mut(2).col_indices(), &[1, 2, 3]);
|
assert_eq!(matrix.row_mut(2).col_indices(), &[1, 2, 3]);
|
||||||
assert_eq!(matrix.row_mut(2).values(), &[2, 3, 4]);
|
assert_eq!(matrix.row_mut(2).values(), &[2, 3, 4]);
|
||||||
assert_eq!(matrix.row_mut(2).values_mut(), &[2, 3, 4]);
|
assert_eq!(matrix.row_mut(2).values_mut(), &[2, 3, 4]);
|
||||||
assert_eq!(matrix.row_mut(2).cols_and_values_mut(), ([1, 2, 3].as_ref(), [2, 3, 4].as_mut()));
|
assert_eq!(
|
||||||
|
matrix.row_mut(2).cols_and_values_mut(),
|
||||||
|
([1, 2, 3].as_ref(), [2, 3, 4].as_mut())
|
||||||
|
);
|
||||||
|
|
||||||
assert!(matrix.get_row(3).is_none());
|
assert!(matrix.get_row(3).is_none());
|
||||||
assert!(matrix.get_row_mut(3).is_none());
|
assert!(matrix.get_row_mut(3).is_none());
|
||||||
|
@ -147,11 +172,13 @@ fn csr_matrix_valid_data() {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn csr_matrix_try_from_invalid_csr_data() {
|
fn csr_matrix_try_from_invalid_csr_data() {
|
||||||
|
|
||||||
{
|
{
|
||||||
// Empty offset array (invalid length)
|
// Empty offset array (invalid length)
|
||||||
let matrix = CsrMatrix::try_from_csr_data(0, 0, Vec::new(), Vec::new(), Vec::<u32>::new());
|
let matrix = CsrMatrix::try_from_csr_data(0, 0, Vec::new(), Vec::new(), Vec::<u32>::new());
|
||||||
assert_eq!(matrix.unwrap_err().kind(), &SparseFormatErrorKind::InvalidStructure);
|
assert_eq!(
|
||||||
|
matrix.unwrap_err().kind(),
|
||||||
|
&SparseFormatErrorKind::InvalidStructure
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
|
@ -161,7 +188,10 @@ fn csr_matrix_try_from_invalid_csr_data() {
|
||||||
let values = vec![0, 1, 2, 3, 4];
|
let values = vec![0, 1, 2, 3, 4];
|
||||||
|
|
||||||
let matrix = CsrMatrix::try_from_csr_data(3, 6, offsets, indices, values);
|
let matrix = CsrMatrix::try_from_csr_data(3, 6, offsets, indices, values);
|
||||||
assert_eq!(matrix.unwrap_err().kind(), &SparseFormatErrorKind::InvalidStructure);
|
assert_eq!(
|
||||||
|
matrix.unwrap_err().kind(),
|
||||||
|
&SparseFormatErrorKind::InvalidStructure
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
|
@ -170,7 +200,10 @@ fn csr_matrix_try_from_invalid_csr_data() {
|
||||||
let indices = vec![0, 5, 1, 2, 3];
|
let indices = vec![0, 5, 1, 2, 3];
|
||||||
let values = vec![0, 1, 2, 3, 4];
|
let values = vec![0, 1, 2, 3, 4];
|
||||||
let matrix = CsrMatrix::try_from_csr_data(3, 6, offsets, indices, values);
|
let matrix = CsrMatrix::try_from_csr_data(3, 6, offsets, indices, values);
|
||||||
assert_eq!(matrix.unwrap_err().kind(), &SparseFormatErrorKind::InvalidStructure);
|
assert_eq!(
|
||||||
|
matrix.unwrap_err().kind(),
|
||||||
|
&SparseFormatErrorKind::InvalidStructure
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
|
@ -179,7 +212,10 @@ fn csr_matrix_try_from_invalid_csr_data() {
|
||||||
let indices = vec![0, 5, 1, 2, 3];
|
let indices = vec![0, 5, 1, 2, 3];
|
||||||
let values = vec![0, 1, 2, 3, 4];
|
let values = vec![0, 1, 2, 3, 4];
|
||||||
let matrix = CsrMatrix::try_from_csr_data(3, 6, offsets, indices, values);
|
let matrix = CsrMatrix::try_from_csr_data(3, 6, offsets, indices, values);
|
||||||
assert_eq!(matrix.unwrap_err().kind(), &SparseFormatErrorKind::InvalidStructure);
|
assert_eq!(
|
||||||
|
matrix.unwrap_err().kind(),
|
||||||
|
&SparseFormatErrorKind::InvalidStructure
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
|
@ -188,7 +224,10 @@ fn csr_matrix_try_from_invalid_csr_data() {
|
||||||
let indices = vec![0, 5, 1, 2, 3];
|
let indices = vec![0, 5, 1, 2, 3];
|
||||||
let values = vec![0, 1, 2, 3, 4];
|
let values = vec![0, 1, 2, 3, 4];
|
||||||
let matrix = CsrMatrix::try_from_csr_data(3, 6, offsets, indices, values);
|
let matrix = CsrMatrix::try_from_csr_data(3, 6, offsets, indices, values);
|
||||||
assert_eq!(matrix.unwrap_err().kind(), &SparseFormatErrorKind::InvalidStructure);
|
assert_eq!(
|
||||||
|
matrix.unwrap_err().kind(),
|
||||||
|
&SparseFormatErrorKind::InvalidStructure
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
|
@ -197,7 +236,10 @@ fn csr_matrix_try_from_invalid_csr_data() {
|
||||||
let indices = vec![0, 1, 2, 3, 4];
|
let indices = vec![0, 1, 2, 3, 4];
|
||||||
let values = vec![0, 1, 2, 3, 4];
|
let values = vec![0, 1, 2, 3, 4];
|
||||||
let matrix = CsrMatrix::try_from_csr_data(3, 6, offsets, indices, values);
|
let matrix = CsrMatrix::try_from_csr_data(3, 6, offsets, indices, values);
|
||||||
assert_eq!(matrix.unwrap_err().kind(), &SparseFormatErrorKind::InvalidStructure);
|
assert_eq!(
|
||||||
|
matrix.unwrap_err().kind(),
|
||||||
|
&SparseFormatErrorKind::InvalidStructure
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
|
@ -206,7 +248,10 @@ fn csr_matrix_try_from_invalid_csr_data() {
|
||||||
let indices = vec![0, 2, 3, 1, 4];
|
let indices = vec![0, 2, 3, 1, 4];
|
||||||
let values = vec![0, 1, 2, 3, 4];
|
let values = vec![0, 1, 2, 3, 4];
|
||||||
let matrix = CsrMatrix::try_from_csr_data(3, 6, offsets, indices, values);
|
let matrix = CsrMatrix::try_from_csr_data(3, 6, offsets, indices, values);
|
||||||
assert_eq!(matrix.unwrap_err().kind(), &SparseFormatErrorKind::InvalidStructure);
|
assert_eq!(
|
||||||
|
matrix.unwrap_err().kind(),
|
||||||
|
&SparseFormatErrorKind::InvalidStructure
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
|
@ -215,7 +260,10 @@ fn csr_matrix_try_from_invalid_csr_data() {
|
||||||
let indices = vec![0, 6, 1, 2, 3];
|
let indices = vec![0, 6, 1, 2, 3];
|
||||||
let values = vec![0, 1, 2, 3, 4];
|
let values = vec![0, 1, 2, 3, 4];
|
||||||
let matrix = CsrMatrix::try_from_csr_data(3, 6, offsets, indices, values);
|
let matrix = CsrMatrix::try_from_csr_data(3, 6, offsets, indices, values);
|
||||||
assert_eq!(matrix.unwrap_err().kind(), &SparseFormatErrorKind::IndexOutOfBounds);
|
assert_eq!(
|
||||||
|
matrix.unwrap_err().kind(),
|
||||||
|
&SparseFormatErrorKind::IndexOutOfBounds
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
|
@ -224,9 +272,11 @@ fn csr_matrix_try_from_invalid_csr_data() {
|
||||||
let indices = vec![0, 5, 2, 2, 3];
|
let indices = vec![0, 5, 2, 2, 3];
|
||||||
let values = vec![0, 1, 2, 3, 4];
|
let values = vec![0, 1, 2, 3, 4];
|
||||||
let matrix = CsrMatrix::try_from_csr_data(3, 6, offsets, indices, values);
|
let matrix = CsrMatrix::try_from_csr_data(3, 6, offsets, indices, values);
|
||||||
assert_eq!(matrix.unwrap_err().kind(), &SparseFormatErrorKind::DuplicateEntry);
|
assert_eq!(
|
||||||
|
matrix.unwrap_err().kind(),
|
||||||
|
&SparseFormatErrorKind::DuplicateEntry
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
@ -240,11 +290,7 @@ fn csr_disassemble_avoids_clone_when_owned() {
|
||||||
let offsets_ptr = offsets.as_ptr();
|
let offsets_ptr = offsets.as_ptr();
|
||||||
let indices_ptr = indices.as_ptr();
|
let indices_ptr = indices.as_ptr();
|
||||||
let values_ptr = values.as_ptr();
|
let values_ptr = values.as_ptr();
|
||||||
let matrix = CsrMatrix::try_from_csr_data(3,
|
let matrix = CsrMatrix::try_from_csr_data(3, 6, offsets, indices, values).unwrap();
|
||||||
6,
|
|
||||||
offsets,
|
|
||||||
indices,
|
|
||||||
values).unwrap();
|
|
||||||
|
|
||||||
let (offsets, indices, values) = matrix.disassemble();
|
let (offsets, indices, values) = matrix.disassemble();
|
||||||
assert_eq!(offsets.as_ptr(), offsets_ptr);
|
assert_eq!(offsets.as_ptr(), offsets_ptr);
|
||||||
|
@ -329,4 +375,4 @@ proptest! {
|
||||||
|
|
||||||
prop_assert_eq!(d_entries, csr_diagonal_entries);
|
prop_assert_eq!(d_entries, csr_diagonal_entries);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,8 +1,8 @@
|
||||||
mod coo;
|
|
||||||
mod cholesky;
|
mod cholesky;
|
||||||
|
mod convert_serial;
|
||||||
|
mod coo;
|
||||||
|
mod csc;
|
||||||
|
mod csr;
|
||||||
mod ops;
|
mod ops;
|
||||||
mod pattern;
|
mod pattern;
|
||||||
mod csr;
|
mod proptest;
|
||||||
mod csc;
|
|
||||||
mod convert_serial;
|
|
||||||
mod proptest;
|
|
||||||
|
|
|
@ -1,13 +1,19 @@
|
||||||
use crate::common::{csc_strategy, csr_strategy, PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ, PROPTEST_I32_VALUE_STRATEGY, non_zero_i32_value_strategy, value_strategy};
|
use crate::common::{
|
||||||
use nalgebra_sparse::ops::serial::{spmm_csr_dense, spmm_csc_dense, spadd_pattern, spadd_csr_prealloc, spadd_csc_prealloc, spmm_csr_prealloc, spmm_csc_prealloc, spsolve_csc_lower_triangular, spmm_csr_pattern};
|
csc_strategy, csr_strategy, non_zero_i32_value_strategy, value_strategy,
|
||||||
use nalgebra_sparse::ops::{Op};
|
PROPTEST_I32_VALUE_STRATEGY, PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ,
|
||||||
use nalgebra_sparse::csr::CsrMatrix;
|
};
|
||||||
use nalgebra_sparse::csc::CscMatrix;
|
use nalgebra_sparse::csc::CscMatrix;
|
||||||
use nalgebra_sparse::proptest::{csc, csr, sparsity_pattern};
|
use nalgebra_sparse::csr::CsrMatrix;
|
||||||
|
use nalgebra_sparse::ops::serial::{
|
||||||
|
spadd_csc_prealloc, spadd_csr_prealloc, spadd_pattern, spmm_csc_dense, spmm_csc_prealloc,
|
||||||
|
spmm_csr_dense, spmm_csr_pattern, spmm_csr_prealloc, spsolve_csc_lower_triangular,
|
||||||
|
};
|
||||||
|
use nalgebra_sparse::ops::Op;
|
||||||
use nalgebra_sparse::pattern::SparsityPattern;
|
use nalgebra_sparse::pattern::SparsityPattern;
|
||||||
|
use nalgebra_sparse::proptest::{csc, csr, sparsity_pattern};
|
||||||
|
|
||||||
use nalgebra::{DMatrix, Scalar, DMatrixSliceMut, DMatrixSlice};
|
|
||||||
use nalgebra::proptest::{matrix, vector};
|
use nalgebra::proptest::{matrix, vector};
|
||||||
|
use nalgebra::{DMatrix, DMatrixSlice, DMatrixSliceMut, Scalar};
|
||||||
|
|
||||||
use proptest::prelude::*;
|
use proptest::prelude::*;
|
||||||
|
|
||||||
|
@ -17,19 +23,15 @@ use std::panic::catch_unwind;
|
||||||
|
|
||||||
/// Represents the sparsity pattern of a CSR matrix as a dense matrix with 0/1
|
/// Represents the sparsity pattern of a CSR matrix as a dense matrix with 0/1
|
||||||
fn dense_csr_pattern(pattern: &SparsityPattern) -> DMatrix<i32> {
|
fn dense_csr_pattern(pattern: &SparsityPattern) -> DMatrix<i32> {
|
||||||
let boolean_csr = CsrMatrix::try_from_pattern_and_values(
|
let boolean_csr =
|
||||||
pattern.clone(),
|
CsrMatrix::try_from_pattern_and_values(pattern.clone(), vec![1; pattern.nnz()]).unwrap();
|
||||||
vec![1; pattern.nnz()])
|
|
||||||
.unwrap();
|
|
||||||
DMatrix::from(&boolean_csr)
|
DMatrix::from(&boolean_csr)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Represents the sparsity pattern of a CSC matrix as a dense matrix with 0/1
|
/// Represents the sparsity pattern of a CSC matrix as a dense matrix with 0/1
|
||||||
fn dense_csc_pattern(pattern: &SparsityPattern) -> DMatrix<i32> {
|
fn dense_csc_pattern(pattern: &SparsityPattern) -> DMatrix<i32> {
|
||||||
let boolean_csc = CscMatrix::try_from_pattern_and_values(
|
let boolean_csc =
|
||||||
pattern.clone(),
|
CscMatrix::try_from_pattern_and_values(pattern.clone(), vec![1; pattern.nnz()]).unwrap();
|
||||||
vec![1; pattern.nnz()])
|
|
||||||
.unwrap();
|
|
||||||
DMatrix::from(&boolean_csc)
|
DMatrix::from(&boolean_csc)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -53,7 +55,7 @@ struct SpmmCscDenseArgs<T: Scalar> {
|
||||||
|
|
||||||
/// Returns matrices C, A and B with compatible dimensions such that it can be used
|
/// Returns matrices C, A and B with compatible dimensions such that it can be used
|
||||||
/// in an `spmm` operation `C = beta * C + alpha * trans(A) * trans(B)`.
|
/// in an `spmm` operation `C = beta * C + alpha * trans(A) * trans(B)`.
|
||||||
fn spmm_csr_dense_args_strategy() -> impl Strategy<Value=SpmmCsrDenseArgs<i32>> {
|
fn spmm_csr_dense_args_strategy() -> impl Strategy<Value = SpmmCsrDenseArgs<i32>> {
|
||||||
let max_nnz = PROPTEST_MAX_NNZ;
|
let max_nnz = PROPTEST_MAX_NNZ;
|
||||||
let value_strategy = PROPTEST_I32_VALUE_STRATEGY;
|
let value_strategy = PROPTEST_I32_VALUE_STRATEGY;
|
||||||
let c_rows = PROPTEST_MATRIX_DIM;
|
let c_rows = PROPTEST_MATRIX_DIM;
|
||||||
|
@ -62,14 +64,23 @@ fn spmm_csr_dense_args_strategy() -> impl Strategy<Value=SpmmCsrDenseArgs<i32>>
|
||||||
let trans_strategy = trans_strategy();
|
let trans_strategy = trans_strategy();
|
||||||
let c_matrix_strategy = matrix(value_strategy.clone(), c_rows, c_cols);
|
let c_matrix_strategy = matrix(value_strategy.clone(), c_rows, c_cols);
|
||||||
|
|
||||||
(c_matrix_strategy, common_dim, trans_strategy.clone(), trans_strategy.clone())
|
(
|
||||||
|
c_matrix_strategy,
|
||||||
|
common_dim,
|
||||||
|
trans_strategy.clone(),
|
||||||
|
trans_strategy.clone(),
|
||||||
|
)
|
||||||
.prop_flat_map(move |(c, common_dim, trans_a, trans_b)| {
|
.prop_flat_map(move |(c, common_dim, trans_a, trans_b)| {
|
||||||
let a_shape =
|
let a_shape = if trans_a {
|
||||||
if trans_a { (common_dim, c.nrows()) }
|
(common_dim, c.nrows())
|
||||||
else { (c.nrows(), common_dim) };
|
} else {
|
||||||
let b_shape =
|
(c.nrows(), common_dim)
|
||||||
if trans_b { (c.ncols(), common_dim) }
|
};
|
||||||
else { (common_dim, c.ncols()) };
|
let b_shape = if trans_b {
|
||||||
|
(c.ncols(), common_dim)
|
||||||
|
} else {
|
||||||
|
(common_dim, c.ncols())
|
||||||
|
};
|
||||||
let a = csr(value_strategy.clone(), a_shape.0, a_shape.1, max_nnz);
|
let a = csr(value_strategy.clone(), a_shape.0, a_shape.1, max_nnz);
|
||||||
let b = matrix(value_strategy.clone(), b_shape.0, b_shape.1);
|
let b = matrix(value_strategy.clone(), b_shape.0, b_shape.1);
|
||||||
|
|
||||||
|
@ -78,30 +89,36 @@ fn spmm_csr_dense_args_strategy() -> impl Strategy<Value=SpmmCsrDenseArgs<i32>>
|
||||||
let beta = value_strategy.clone();
|
let beta = value_strategy.clone();
|
||||||
|
|
||||||
(Just(c), beta, alpha, Just(trans_a), a, Just(trans_b), b)
|
(Just(c), beta, alpha, Just(trans_a), a, Just(trans_b), b)
|
||||||
}).prop_map(|(c, beta, alpha, trans_a, a, trans_b, b)| {
|
})
|
||||||
SpmmCsrDenseArgs {
|
.prop_map(
|
||||||
|
|(c, beta, alpha, trans_a, a, trans_b, b)| SpmmCsrDenseArgs {
|
||||||
c,
|
c,
|
||||||
beta,
|
beta,
|
||||||
alpha,
|
alpha,
|
||||||
a: if trans_a { Op::Transpose(a) } else { Op::NoOp(a) },
|
a: if trans_a {
|
||||||
b: if trans_b { Op::Transpose(b) } else { Op::NoOp(b) },
|
Op::Transpose(a)
|
||||||
}
|
} else {
|
||||||
})
|
Op::NoOp(a)
|
||||||
|
},
|
||||||
|
b: if trans_b {
|
||||||
|
Op::Transpose(b)
|
||||||
|
} else {
|
||||||
|
Op::NoOp(b)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns matrices C, A and B with compatible dimensions such that it can be used
|
/// Returns matrices C, A and B with compatible dimensions such that it can be used
|
||||||
/// in an `spmm` operation `C = beta * C + alpha * trans(A) * trans(B)`.
|
/// in an `spmm` operation `C = beta * C + alpha * trans(A) * trans(B)`.
|
||||||
fn spmm_csc_dense_args_strategy() -> impl Strategy<Value=SpmmCscDenseArgs<i32>> {
|
fn spmm_csc_dense_args_strategy() -> impl Strategy<Value = SpmmCscDenseArgs<i32>> {
|
||||||
spmm_csr_dense_args_strategy()
|
spmm_csr_dense_args_strategy().prop_map(|args| SpmmCscDenseArgs {
|
||||||
.prop_map(|args| {
|
c: args.c,
|
||||||
SpmmCscDenseArgs {
|
beta: args.beta,
|
||||||
c: args.c,
|
alpha: args.alpha,
|
||||||
beta: args.beta,
|
a: args.a.map_same_op(|a| CscMatrix::from(&a)),
|
||||||
alpha: args.alpha,
|
b: args.b,
|
||||||
a: args.a.map_same_op(|a| CscMatrix::from(&a)),
|
})
|
||||||
b: args.b
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
|
@ -120,7 +137,7 @@ struct SpaddCscArgs<T> {
|
||||||
a: Op<CscMatrix<T>>,
|
a: Op<CscMatrix<T>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn spadd_csr_prealloc_args_strategy() -> impl Strategy<Value=SpaddCsrArgs<i32>> {
|
fn spadd_csr_prealloc_args_strategy() -> impl Strategy<Value = SpaddCsrArgs<i32>> {
|
||||||
let value_strategy = PROPTEST_I32_VALUE_STRATEGY;
|
let value_strategy = PROPTEST_I32_VALUE_STRATEGY;
|
||||||
|
|
||||||
spadd_pattern_strategy()
|
spadd_pattern_strategy()
|
||||||
|
@ -131,66 +148,83 @@ fn spadd_csr_prealloc_args_strategy() -> impl Strategy<Value=SpaddCsrArgs<i32>>
|
||||||
let c_values = vec![value_strategy.clone(); c_pattern.nnz()];
|
let c_values = vec![value_strategy.clone(); c_pattern.nnz()];
|
||||||
let alpha = value_strategy.clone();
|
let alpha = value_strategy.clone();
|
||||||
let beta = value_strategy.clone();
|
let beta = value_strategy.clone();
|
||||||
(Just(c_pattern), Just(a_pattern), c_values, a_values, alpha, beta, trans_strategy())
|
(
|
||||||
}).prop_map(|(c_pattern, a_pattern, c_values, a_values, alpha, beta, trans_a)| {
|
Just(c_pattern),
|
||||||
let c = CsrMatrix::try_from_pattern_and_values(c_pattern, c_values).unwrap();
|
Just(a_pattern),
|
||||||
let a = CsrMatrix::try_from_pattern_and_values(a_pattern, a_values).unwrap();
|
c_values,
|
||||||
|
a_values,
|
||||||
let a = if trans_a { Op::Transpose(a.transpose()) } else { Op::NoOp(a) };
|
alpha,
|
||||||
SpaddCsrArgs { c, beta, alpha, a }
|
beta,
|
||||||
|
trans_strategy(),
|
||||||
|
)
|
||||||
})
|
})
|
||||||
|
.prop_map(
|
||||||
|
|(c_pattern, a_pattern, c_values, a_values, alpha, beta, trans_a)| {
|
||||||
|
let c = CsrMatrix::try_from_pattern_and_values(c_pattern, c_values).unwrap();
|
||||||
|
let a = CsrMatrix::try_from_pattern_and_values(a_pattern, a_values).unwrap();
|
||||||
|
|
||||||
|
let a = if trans_a {
|
||||||
|
Op::Transpose(a.transpose())
|
||||||
|
} else {
|
||||||
|
Op::NoOp(a)
|
||||||
|
};
|
||||||
|
SpaddCsrArgs { c, beta, alpha, a }
|
||||||
|
},
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn spadd_csc_prealloc_args_strategy() -> impl Strategy<Value=SpaddCscArgs<i32>> {
|
fn spadd_csc_prealloc_args_strategy() -> impl Strategy<Value = SpaddCscArgs<i32>> {
|
||||||
spadd_csr_prealloc_args_strategy()
|
spadd_csr_prealloc_args_strategy().prop_map(|args| SpaddCscArgs {
|
||||||
.prop_map(|args| SpaddCscArgs {
|
c: CscMatrix::from(&args.c),
|
||||||
c: CscMatrix::from(&args.c),
|
beta: args.beta,
|
||||||
beta: args.beta,
|
alpha: args.alpha,
|
||||||
alpha: args.alpha,
|
a: args.a.map_same_op(|a| CscMatrix::from(&a)),
|
||||||
a: args.a.map_same_op(|a| CscMatrix::from(&a))
|
})
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn dense_strategy() -> impl Strategy<Value=DMatrix<i32>> {
|
fn dense_strategy() -> impl Strategy<Value = DMatrix<i32>> {
|
||||||
matrix(PROPTEST_I32_VALUE_STRATEGY, PROPTEST_MATRIX_DIM, PROPTEST_MATRIX_DIM)
|
matrix(
|
||||||
|
PROPTEST_I32_VALUE_STRATEGY,
|
||||||
|
PROPTEST_MATRIX_DIM,
|
||||||
|
PROPTEST_MATRIX_DIM,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn trans_strategy() -> impl Strategy<Value=bool> + Clone {
|
fn trans_strategy() -> impl Strategy<Value = bool> + Clone {
|
||||||
proptest::bool::ANY
|
proptest::bool::ANY
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Wraps the values of the given strategy in `Op`, producing both transposed and non-transposed
|
/// Wraps the values of the given strategy in `Op`, producing both transposed and non-transposed
|
||||||
/// values.
|
/// values.
|
||||||
fn op_strategy<S: Strategy>(strategy: S) -> impl Strategy<Value=Op<S::Value>> {
|
fn op_strategy<S: Strategy>(strategy: S) -> impl Strategy<Value = Op<S::Value>> {
|
||||||
let is_transposed = proptest::bool::ANY;
|
let is_transposed = proptest::bool::ANY;
|
||||||
(strategy, is_transposed)
|
(strategy, is_transposed).prop_map(|(obj, is_trans)| {
|
||||||
.prop_map(|(obj, is_trans)| if is_trans {
|
if is_trans {
|
||||||
Op::Transpose(obj)
|
Op::Transpose(obj)
|
||||||
} else {
|
} else {
|
||||||
Op::NoOp(obj)
|
Op::NoOp(obj)
|
||||||
})
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn pattern_strategy() -> impl Strategy<Value=SparsityPattern> {
|
fn pattern_strategy() -> impl Strategy<Value = SparsityPattern> {
|
||||||
sparsity_pattern(PROPTEST_MATRIX_DIM, PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ)
|
sparsity_pattern(PROPTEST_MATRIX_DIM, PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Constructs pairs (a, b) where a and b have the same dimensions
|
/// Constructs pairs (a, b) where a and b have the same dimensions
|
||||||
fn spadd_pattern_strategy() -> impl Strategy<Value=(SparsityPattern, SparsityPattern)> {
|
fn spadd_pattern_strategy() -> impl Strategy<Value = (SparsityPattern, SparsityPattern)> {
|
||||||
pattern_strategy()
|
pattern_strategy().prop_flat_map(|a| {
|
||||||
.prop_flat_map(|a| {
|
let b = sparsity_pattern(a.major_dim(), a.minor_dim(), PROPTEST_MAX_NNZ);
|
||||||
let b = sparsity_pattern(a.major_dim(), a.minor_dim(), PROPTEST_MAX_NNZ);
|
(Just(a), b)
|
||||||
(Just(a), b)
|
})
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Constructs pairs (a, b) where a and b have compatible dimensions for a matrix product
|
/// Constructs pairs (a, b) where a and b have compatible dimensions for a matrix product
|
||||||
fn spmm_csr_pattern_strategy() -> impl Strategy<Value=(SparsityPattern, SparsityPattern)> {
|
fn spmm_csr_pattern_strategy() -> impl Strategy<Value = (SparsityPattern, SparsityPattern)> {
|
||||||
pattern_strategy()
|
pattern_strategy().prop_flat_map(|a| {
|
||||||
.prop_flat_map(|a| {
|
let b = sparsity_pattern(a.minor_dim(), PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ);
|
||||||
let b = sparsity_pattern(a.minor_dim(), PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ);
|
(Just(a), b)
|
||||||
(Just(a), b)
|
})
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
|
@ -211,86 +245,98 @@ struct SpmmCscArgs<T> {
|
||||||
b: Op<CscMatrix<T>>,
|
b: Op<CscMatrix<T>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn spmm_csr_prealloc_args_strategy() -> impl Strategy<Value=SpmmCsrArgs<i32>> {
|
fn spmm_csr_prealloc_args_strategy() -> impl Strategy<Value = SpmmCsrArgs<i32>> {
|
||||||
spmm_csr_pattern_strategy()
|
spmm_csr_pattern_strategy()
|
||||||
.prop_flat_map(|(a_pattern, b_pattern)| {
|
.prop_flat_map(|(a_pattern, b_pattern)| {
|
||||||
let a_values = vec![PROPTEST_I32_VALUE_STRATEGY; a_pattern.nnz()];
|
let a_values = vec![PROPTEST_I32_VALUE_STRATEGY; a_pattern.nnz()];
|
||||||
let b_values = vec![PROPTEST_I32_VALUE_STRATEGY; b_pattern.nnz()];
|
let b_values = vec![PROPTEST_I32_VALUE_STRATEGY; b_pattern.nnz()];
|
||||||
let c_pattern = spmm_csr_pattern(&a_pattern, &b_pattern);
|
let c_pattern = spmm_csr_pattern(&a_pattern, &b_pattern);
|
||||||
let c_values = vec![PROPTEST_I32_VALUE_STRATEGY; c_pattern.nnz()];
|
let c_values = vec![PROPTEST_I32_VALUE_STRATEGY; c_pattern.nnz()];
|
||||||
let a = a_values.prop_map(move |values|
|
let a = a_values.prop_map(move |values| {
|
||||||
CsrMatrix::try_from_pattern_and_values(a_pattern.clone(), values).unwrap());
|
CsrMatrix::try_from_pattern_and_values(a_pattern.clone(), values).unwrap()
|
||||||
let b = b_values.prop_map(move |values|
|
});
|
||||||
CsrMatrix::try_from_pattern_and_values(b_pattern.clone(), values).unwrap());
|
let b = b_values.prop_map(move |values| {
|
||||||
let c = c_values.prop_map(move |values|
|
CsrMatrix::try_from_pattern_and_values(b_pattern.clone(), values).unwrap()
|
||||||
CsrMatrix::try_from_pattern_and_values(c_pattern.clone(), values).unwrap());
|
});
|
||||||
|
let c = c_values.prop_map(move |values| {
|
||||||
|
CsrMatrix::try_from_pattern_and_values(c_pattern.clone(), values).unwrap()
|
||||||
|
});
|
||||||
let alpha = PROPTEST_I32_VALUE_STRATEGY;
|
let alpha = PROPTEST_I32_VALUE_STRATEGY;
|
||||||
let beta = PROPTEST_I32_VALUE_STRATEGY;
|
let beta = PROPTEST_I32_VALUE_STRATEGY;
|
||||||
(c, beta, alpha, trans_strategy(), a, trans_strategy(), b)
|
(c, beta, alpha, trans_strategy(), a, trans_strategy(), b)
|
||||||
})
|
})
|
||||||
.prop_map(|(c, beta, alpha, trans_a, a, trans_b, b)| {
|
.prop_map(
|
||||||
SpmmCsrArgs::<i32> {
|
|(c, beta, alpha, trans_a, a, trans_b, b)| SpmmCsrArgs::<i32> {
|
||||||
c,
|
c,
|
||||||
beta,
|
beta,
|
||||||
alpha,
|
alpha,
|
||||||
a: if trans_a { Op::Transpose(a.transpose()) } else { Op::NoOp(a) },
|
a: if trans_a {
|
||||||
b: if trans_b { Op::Transpose(b.transpose()) } else { Op::NoOp(b) }
|
Op::Transpose(a.transpose())
|
||||||
}
|
} else {
|
||||||
})
|
Op::NoOp(a)
|
||||||
|
},
|
||||||
|
b: if trans_b {
|
||||||
|
Op::Transpose(b.transpose())
|
||||||
|
} else {
|
||||||
|
Op::NoOp(b)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn spmm_csc_prealloc_args_strategy() -> impl Strategy<Value=SpmmCscArgs<i32>> {
|
fn spmm_csc_prealloc_args_strategy() -> impl Strategy<Value = SpmmCscArgs<i32>> {
|
||||||
// Note: Converting from CSR is simple, but might be significantly slower than
|
// Note: Converting from CSR is simple, but might be significantly slower than
|
||||||
// writing a common implementation that can be shared between CSR and CSC args
|
// writing a common implementation that can be shared between CSR and CSC args
|
||||||
spmm_csr_prealloc_args_strategy()
|
spmm_csr_prealloc_args_strategy().prop_map(|args| SpmmCscArgs {
|
||||||
.prop_map(|args| {
|
c: CscMatrix::from(&args.c),
|
||||||
SpmmCscArgs {
|
beta: args.beta,
|
||||||
c: CscMatrix::from(&args.c),
|
alpha: args.alpha,
|
||||||
beta: args.beta,
|
a: args.a.map_same_op(|a| CscMatrix::from(&a)),
|
||||||
alpha: args.alpha,
|
b: args.b.map_same_op(|b| CscMatrix::from(&b)),
|
||||||
a: args.a.map_same_op(|a| CscMatrix::from(&a)),
|
})
|
||||||
b: args.b.map_same_op(|b| CscMatrix::from(&b))
|
}
|
||||||
|
|
||||||
|
fn csc_invertible_diagonal() -> impl Strategy<Value = CscMatrix<f64>> {
|
||||||
|
let non_zero_values =
|
||||||
|
value_strategy::<f64>().prop_filter("Only non-zeros values accepted", |x| x != &0.0);
|
||||||
|
|
||||||
|
vector(non_zero_values, PROPTEST_MATRIX_DIM).prop_map(|d| {
|
||||||
|
let mut matrix = CscMatrix::identity(d.len());
|
||||||
|
matrix.values_mut().clone_from_slice(&d.as_slice());
|
||||||
|
matrix
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn csc_square_with_non_zero_diagonals() -> impl Strategy<Value = CscMatrix<f64>> {
|
||||||
|
csc_invertible_diagonal().prop_flat_map(|d| {
|
||||||
|
csc(
|
||||||
|
value_strategy::<f64>(),
|
||||||
|
d.nrows(),
|
||||||
|
d.nrows(),
|
||||||
|
PROPTEST_MAX_NNZ,
|
||||||
|
)
|
||||||
|
.prop_map(move |mut c| {
|
||||||
|
for (i, j, v) in c.triplet_iter_mut() {
|
||||||
|
if i == j {
|
||||||
|
*v = 0.0;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Return the sum of a matrix with zero diagonals and an invertible diagonal
|
||||||
|
// matrix
|
||||||
|
c + &d
|
||||||
})
|
})
|
||||||
}
|
})
|
||||||
|
|
||||||
fn csc_invertible_diagonal() -> impl Strategy<Value=CscMatrix<f64>> {
|
|
||||||
let non_zero_values = value_strategy::<f64>()
|
|
||||||
.prop_filter("Only non-zeros values accepted", |x| x != &0.0);
|
|
||||||
|
|
||||||
vector(non_zero_values, PROPTEST_MATRIX_DIM)
|
|
||||||
.prop_map(|d| {
|
|
||||||
let mut matrix = CscMatrix::identity(d.len());
|
|
||||||
matrix.values_mut().clone_from_slice(&d.as_slice());
|
|
||||||
matrix
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
fn csc_square_with_non_zero_diagonals() -> impl Strategy<Value=CscMatrix<f64>> {
|
|
||||||
csc_invertible_diagonal()
|
|
||||||
.prop_flat_map(|d| {
|
|
||||||
csc(value_strategy::<f64>(), d.nrows(), d.nrows(), PROPTEST_MAX_NNZ)
|
|
||||||
.prop_map(move |mut c| {
|
|
||||||
for (i, j, v) in c.triplet_iter_mut() {
|
|
||||||
if i == j {
|
|
||||||
*v = 0.0;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Return the sum of a matrix with zero diagonals and an invertible diagonal
|
|
||||||
// matrix
|
|
||||||
c + &d
|
|
||||||
})
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Helper function to help us call dense GEMM with our `Op` type
|
/// Helper function to help us call dense GEMM with our `Op` type
|
||||||
fn dense_gemm<'a>(beta: i32,
|
fn dense_gemm<'a>(
|
||||||
c: impl Into<DMatrixSliceMut<'a, i32>>,
|
beta: i32,
|
||||||
alpha: i32,
|
c: impl Into<DMatrixSliceMut<'a, i32>>,
|
||||||
a: Op<impl Into<DMatrixSlice<'a, i32>>>,
|
alpha: i32,
|
||||||
b: Op<impl Into<DMatrixSlice<'a, i32>>>)
|
a: Op<impl Into<DMatrixSlice<'a, i32>>>,
|
||||||
{
|
b: Op<impl Into<DMatrixSlice<'a, i32>>>,
|
||||||
|
) {
|
||||||
let mut c = c.into();
|
let mut c = c.into();
|
||||||
let a = a.convert();
|
let a = a.convert();
|
||||||
let b = b.convert();
|
let b = b.convert();
|
||||||
|
@ -300,7 +346,7 @@ fn dense_gemm<'a>(beta: i32,
|
||||||
(NoOp(a), NoOp(b)) => c.gemm(alpha, &a, &b, beta),
|
(NoOp(a), NoOp(b)) => c.gemm(alpha, &a, &b, beta),
|
||||||
(Transpose(a), NoOp(b)) => c.gemm(alpha, &a.transpose(), &b, beta),
|
(Transpose(a), NoOp(b)) => c.gemm(alpha, &a.transpose(), &b, beta),
|
||||||
(NoOp(a), Transpose(b)) => c.gemm(alpha, &a, &b.transpose(), beta),
|
(NoOp(a), Transpose(b)) => c.gemm(alpha, &a, &b.transpose(), beta),
|
||||||
(Transpose(a), Transpose(b)) => c.gemm(alpha, &a.transpose(), &b.transpose(), beta)
|
(Transpose(a), Transpose(b)) => c.gemm(alpha, &a.transpose(), &b.transpose(), beta),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1186,4 +1232,4 @@ proptest! {
|
||||||
prop_assert_matrix_eq!(&a_lower.transpose() * &x, &b, comp = abs, tol = 1e-4);
|
prop_assert_matrix_eq!(&a_lower.transpose() * &x, &b, comp = abs, tol = 1e-4);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,11 +7,9 @@ fn sparsity_pattern_valid_data() {
|
||||||
|
|
||||||
{
|
{
|
||||||
// A pattern with zero explicitly stored entries
|
// A pattern with zero explicitly stored entries
|
||||||
let pattern = SparsityPattern::try_from_offsets_and_indices(3,
|
let pattern =
|
||||||
2,
|
SparsityPattern::try_from_offsets_and_indices(3, 2, vec![0, 0, 0, 0], Vec::new())
|
||||||
vec![0, 0, 0, 0],
|
.unwrap();
|
||||||
Vec::new())
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
assert_eq!(pattern.major_dim(), 3);
|
assert_eq!(pattern.major_dim(), 3);
|
||||||
assert_eq!(pattern.minor_dim(), 2);
|
assert_eq!(pattern.minor_dim(), 2);
|
||||||
|
@ -36,7 +34,7 @@ fn sparsity_pattern_valid_data() {
|
||||||
let indices = vec![0, 5, 1, 2, 3];
|
let indices = vec![0, 5, 1, 2, 3];
|
||||||
let pattern =
|
let pattern =
|
||||||
SparsityPattern::try_from_offsets_and_indices(3, 6, offsets.clone(), indices.clone())
|
SparsityPattern::try_from_offsets_and_indices(3, 6, offsets.clone(), indices.clone())
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
assert_eq!(pattern.major_dim(), 3);
|
assert_eq!(pattern.major_dim(), 3);
|
||||||
assert_eq!(pattern.minor_dim(), 6);
|
assert_eq!(pattern.minor_dim(), 6);
|
||||||
|
@ -46,8 +44,10 @@ fn sparsity_pattern_valid_data() {
|
||||||
assert_eq!(pattern.lane(0), &[0, 5]);
|
assert_eq!(pattern.lane(0), &[0, 5]);
|
||||||
assert_eq!(pattern.lane(1), &[]);
|
assert_eq!(pattern.lane(1), &[]);
|
||||||
assert_eq!(pattern.lane(2), &[1, 2, 3]);
|
assert_eq!(pattern.lane(2), &[1, 2, 3]);
|
||||||
assert_eq!(pattern.entries().collect::<Vec<_>>(),
|
assert_eq!(
|
||||||
vec![(0, 0), (0, 5), (2, 1), (2, 2), (2, 3)]);
|
pattern.entries().collect::<Vec<_>>(),
|
||||||
|
vec![(0, 0), (0, 5), (2, 1), (2, 2), (2, 3)]
|
||||||
|
);
|
||||||
|
|
||||||
let (offsets2, indices2) = pattern.disassemble();
|
let (offsets2, indices2) = pattern.disassemble();
|
||||||
assert_eq!(offsets2, offsets);
|
assert_eq!(offsets2, offsets);
|
||||||
|
@ -60,7 +60,10 @@ fn sparsity_pattern_try_from_invalid_data() {
|
||||||
{
|
{
|
||||||
// Empty offset array (invalid length)
|
// Empty offset array (invalid length)
|
||||||
let pattern = SparsityPattern::try_from_offsets_and_indices(0, 0, Vec::new(), Vec::new());
|
let pattern = SparsityPattern::try_from_offsets_and_indices(0, 0, Vec::new(), Vec::new());
|
||||||
assert_eq!(pattern, Err(SparsityPatternFormatError::InvalidOffsetArrayLength));
|
assert_eq!(
|
||||||
|
pattern,
|
||||||
|
Err(SparsityPatternFormatError::InvalidOffsetArrayLength)
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
|
@ -69,7 +72,10 @@ fn sparsity_pattern_try_from_invalid_data() {
|
||||||
let indices = vec![0, 1, 2, 3, 5];
|
let indices = vec![0, 1, 2, 3, 5];
|
||||||
|
|
||||||
let pattern = SparsityPattern::try_from_offsets_and_indices(3, 6, offsets, indices);
|
let pattern = SparsityPattern::try_from_offsets_and_indices(3, 6, offsets, indices);
|
||||||
assert!(matches!(pattern, Err(SparsityPatternFormatError::InvalidOffsetArrayLength)));
|
assert!(matches!(
|
||||||
|
pattern,
|
||||||
|
Err(SparsityPatternFormatError::InvalidOffsetArrayLength)
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
|
@ -77,7 +83,10 @@ fn sparsity_pattern_try_from_invalid_data() {
|
||||||
let offsets = vec![1, 2, 2, 5];
|
let offsets = vec![1, 2, 2, 5];
|
||||||
let indices = vec![0, 5, 1, 2, 3];
|
let indices = vec![0, 5, 1, 2, 3];
|
||||||
let pattern = SparsityPattern::try_from_offsets_and_indices(3, 6, offsets, indices);
|
let pattern = SparsityPattern::try_from_offsets_and_indices(3, 6, offsets, indices);
|
||||||
assert!(matches!(pattern, Err(SparsityPatternFormatError::InvalidOffsetFirstLast)));
|
assert!(matches!(
|
||||||
|
pattern,
|
||||||
|
Err(SparsityPatternFormatError::InvalidOffsetFirstLast)
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
|
@ -85,7 +94,10 @@ fn sparsity_pattern_try_from_invalid_data() {
|
||||||
let offsets = vec![0, 2, 2, 4];
|
let offsets = vec![0, 2, 2, 4];
|
||||||
let indices = vec![0, 5, 1, 2, 3];
|
let indices = vec![0, 5, 1, 2, 3];
|
||||||
let pattern = SparsityPattern::try_from_offsets_and_indices(3, 6, offsets, indices);
|
let pattern = SparsityPattern::try_from_offsets_and_indices(3, 6, offsets, indices);
|
||||||
assert!(matches!(pattern, Err(SparsityPatternFormatError::InvalidOffsetFirstLast)));
|
assert!(matches!(
|
||||||
|
pattern,
|
||||||
|
Err(SparsityPatternFormatError::InvalidOffsetFirstLast)
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
|
@ -93,7 +105,10 @@ fn sparsity_pattern_try_from_invalid_data() {
|
||||||
let offsets = vec![0, 2, 2];
|
let offsets = vec![0, 2, 2];
|
||||||
let indices = vec![0, 5, 1, 2, 3];
|
let indices = vec![0, 5, 1, 2, 3];
|
||||||
let pattern = SparsityPattern::try_from_offsets_and_indices(3, 6, offsets, indices);
|
let pattern = SparsityPattern::try_from_offsets_and_indices(3, 6, offsets, indices);
|
||||||
assert!(matches!(pattern, Err(SparsityPatternFormatError::InvalidOffsetArrayLength)));
|
assert!(matches!(
|
||||||
|
pattern,
|
||||||
|
Err(SparsityPatternFormatError::InvalidOffsetArrayLength)
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
|
@ -101,7 +116,10 @@ fn sparsity_pattern_try_from_invalid_data() {
|
||||||
let offsets = vec![0, 3, 2, 5];
|
let offsets = vec![0, 3, 2, 5];
|
||||||
let indices = vec![0, 1, 2, 3, 4];
|
let indices = vec![0, 1, 2, 3, 4];
|
||||||
let pattern = SparsityPattern::try_from_offsets_and_indices(3, 6, offsets, indices);
|
let pattern = SparsityPattern::try_from_offsets_and_indices(3, 6, offsets, indices);
|
||||||
assert_eq!(pattern, Err(SparsityPatternFormatError::NonmonotonicOffsets));
|
assert_eq!(
|
||||||
|
pattern,
|
||||||
|
Err(SparsityPatternFormatError::NonmonotonicOffsets)
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
|
@ -109,7 +127,10 @@ fn sparsity_pattern_try_from_invalid_data() {
|
||||||
let offsets = vec![0, 2, 2, 5];
|
let offsets = vec![0, 2, 2, 5];
|
||||||
let indices = vec![0, 2, 3, 1, 4];
|
let indices = vec![0, 2, 3, 1, 4];
|
||||||
let pattern = SparsityPattern::try_from_offsets_and_indices(3, 6, offsets, indices);
|
let pattern = SparsityPattern::try_from_offsets_and_indices(3, 6, offsets, indices);
|
||||||
assert_eq!(pattern, Err(SparsityPatternFormatError::NonmonotonicMinorIndices));
|
assert_eq!(
|
||||||
|
pattern,
|
||||||
|
Err(SparsityPatternFormatError::NonmonotonicMinorIndices)
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
|
@ -117,7 +138,10 @@ fn sparsity_pattern_try_from_invalid_data() {
|
||||||
let offsets = vec![0, 2, 2, 5];
|
let offsets = vec![0, 2, 2, 5];
|
||||||
let indices = vec![0, 6, 1, 2, 3];
|
let indices = vec![0, 6, 1, 2, 3];
|
||||||
let pattern = SparsityPattern::try_from_offsets_and_indices(3, 6, offsets, indices);
|
let pattern = SparsityPattern::try_from_offsets_and_indices(3, 6, offsets, indices);
|
||||||
assert_eq!(pattern, Err(SparsityPatternFormatError::MinorIndexOutOfBounds));
|
assert_eq!(
|
||||||
|
pattern,
|
||||||
|
Err(SparsityPatternFormatError::MinorIndexOutOfBounds)
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
|
@ -127,4 +151,4 @@ fn sparsity_pattern_try_from_invalid_data() {
|
||||||
let pattern = SparsityPattern::try_from_offsets_and_indices(3, 6, offsets, indices);
|
let pattern = SparsityPattern::try_from_offsets_and_indices(3, 6, offsets, indices);
|
||||||
assert_eq!(pattern, Err(SparsityPatternFormatError::DuplicateEntry));
|
assert_eq!(pattern, Err(SparsityPatternFormatError::DuplicateEntry));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,25 +6,27 @@ fn coo_no_duplicates_generates_admissible_matrices() {
|
||||||
|
|
||||||
#[cfg(feature = "slow-tests")]
|
#[cfg(feature = "slow-tests")]
|
||||||
mod slow {
|
mod slow {
|
||||||
use nalgebra_sparse::proptest::{coo_with_duplicates, coo_no_duplicates, csr, csc, sparsity_pattern};
|
|
||||||
use nalgebra::DMatrix;
|
use nalgebra::DMatrix;
|
||||||
|
use nalgebra_sparse::proptest::{
|
||||||
|
coo_no_duplicates, coo_with_duplicates, csc, csr, sparsity_pattern,
|
||||||
|
};
|
||||||
|
|
||||||
use proptest::test_runner::TestRunner;
|
|
||||||
use proptest::strategy::ValueTree;
|
|
||||||
use itertools::Itertools;
|
use itertools::Itertools;
|
||||||
|
use proptest::strategy::ValueTree;
|
||||||
|
use proptest::test_runner::TestRunner;
|
||||||
|
|
||||||
use proptest::prelude::*;
|
use proptest::prelude::*;
|
||||||
|
|
||||||
|
use nalgebra_sparse::csr::CsrMatrix;
|
||||||
use std::collections::HashSet;
|
use std::collections::HashSet;
|
||||||
use std::iter::repeat;
|
use std::iter::repeat;
|
||||||
use std::ops::RangeInclusive;
|
use std::ops::RangeInclusive;
|
||||||
use nalgebra_sparse::csr::CsrMatrix;
|
|
||||||
|
|
||||||
fn generate_all_possible_matrices(value_range: RangeInclusive<i32>,
|
fn generate_all_possible_matrices(
|
||||||
rows_range: RangeInclusive<usize>,
|
value_range: RangeInclusive<i32>,
|
||||||
cols_range: RangeInclusive<usize>)
|
rows_range: RangeInclusive<usize>,
|
||||||
-> HashSet<DMatrix<i32>>
|
cols_range: RangeInclusive<usize>,
|
||||||
{
|
) -> HashSet<DMatrix<i32>> {
|
||||||
// Enumerate all possible combinations
|
// Enumerate all possible combinations
|
||||||
let mut all_combinations = HashSet::new();
|
let mut all_combinations = HashSet::new();
|
||||||
for nrows in rows_range {
|
for nrows in rows_range {
|
||||||
|
@ -48,7 +50,11 @@ mod slow {
|
||||||
.take(n_values)
|
.take(n_values)
|
||||||
.multi_cartesian_product();
|
.multi_cartesian_product();
|
||||||
for matrix_values in values_iter {
|
for matrix_values in values_iter {
|
||||||
all_combinations.insert(DMatrix::from_row_slice(nrows, ncols, &matrix_values));
|
all_combinations.insert(DMatrix::from_row_slice(
|
||||||
|
nrows,
|
||||||
|
ncols,
|
||||||
|
&matrix_values,
|
||||||
|
));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -80,12 +86,14 @@ mod slow {
|
||||||
// Enumerate all possible combinations
|
// Enumerate all possible combinations
|
||||||
let all_combinations = generate_all_possible_matrices(values, rows, cols);
|
let all_combinations = generate_all_possible_matrices(values, rows, cols);
|
||||||
|
|
||||||
let visited_combinations = sample_matrix_output_space(strategy,
|
let visited_combinations =
|
||||||
&mut runner,
|
sample_matrix_output_space(strategy, &mut runner, num_generated_matrices);
|
||||||
num_generated_matrices);
|
|
||||||
|
|
||||||
assert_eq!(visited_combinations.len(), all_combinations.len());
|
assert_eq!(visited_combinations.len(), all_combinations.len());
|
||||||
assert_eq!(visited_combinations, all_combinations, "Did not sample all possible values.");
|
assert_eq!(
|
||||||
|
visited_combinations, all_combinations,
|
||||||
|
"Did not sample all possible values."
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature = "slow-tests")]
|
#[cfg(feature = "slow-tests")]
|
||||||
|
@ -113,9 +121,8 @@ mod slow {
|
||||||
// `coo_with_duplicates`)
|
// `coo_with_duplicates`)
|
||||||
let all_combinations = generate_all_possible_matrices(values, rows, cols);
|
let all_combinations = generate_all_possible_matrices(values, rows, cols);
|
||||||
|
|
||||||
let visited_combinations = sample_matrix_output_space(strategy,
|
let visited_combinations =
|
||||||
&mut runner,
|
sample_matrix_output_space(strategy, &mut runner, num_generated_matrices);
|
||||||
num_generated_matrices);
|
|
||||||
|
|
||||||
// Here we cannot verify that the set of visited combinations is *equal* to
|
// Here we cannot verify that the set of visited combinations is *equal* to
|
||||||
// all possible outcomes with the given constraints, however the
|
// all possible outcomes with the given constraints, however the
|
||||||
|
@ -143,12 +150,14 @@ mod slow {
|
||||||
|
|
||||||
let all_combinations = generate_all_possible_matrices(values, rows, cols);
|
let all_combinations = generate_all_possible_matrices(values, rows, cols);
|
||||||
|
|
||||||
let visited_combinations = sample_matrix_output_space(strategy,
|
let visited_combinations =
|
||||||
&mut runner,
|
sample_matrix_output_space(strategy, &mut runner, num_generated_matrices);
|
||||||
num_generated_matrices);
|
|
||||||
|
|
||||||
assert_eq!(visited_combinations.len(), all_combinations.len());
|
assert_eq!(visited_combinations.len(), all_combinations.len());
|
||||||
assert_eq!(visited_combinations, all_combinations, "Did not sample all possible values.");
|
assert_eq!(
|
||||||
|
visited_combinations, all_combinations,
|
||||||
|
"Did not sample all possible values."
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature = "slow-tests")]
|
#[cfg(feature = "slow-tests")]
|
||||||
|
@ -169,12 +178,14 @@ mod slow {
|
||||||
|
|
||||||
let all_combinations = generate_all_possible_matrices(values, rows, cols);
|
let all_combinations = generate_all_possible_matrices(values, rows, cols);
|
||||||
|
|
||||||
let visited_combinations = sample_matrix_output_space(strategy,
|
let visited_combinations =
|
||||||
&mut runner,
|
sample_matrix_output_space(strategy, &mut runner, num_generated_matrices);
|
||||||
num_generated_matrices);
|
|
||||||
|
|
||||||
assert_eq!(visited_combinations.len(), all_combinations.len());
|
assert_eq!(visited_combinations.len(), all_combinations.len());
|
||||||
assert_eq!(visited_combinations, all_combinations, "Did not sample all possible values.");
|
assert_eq!(
|
||||||
|
visited_combinations, all_combinations,
|
||||||
|
"Did not sample all possible values."
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature = "slow-tests")]
|
#[cfg(feature = "slow-tests")]
|
||||||
|
@ -206,13 +217,14 @@ mod slow {
|
||||||
assert_eq!(visited_patterns, all_possible_patterns);
|
assert_eq!(visited_patterns, all_possible_patterns);
|
||||||
}
|
}
|
||||||
|
|
||||||
fn sample_matrix_output_space<S>(strategy: S,
|
fn sample_matrix_output_space<S>(
|
||||||
runner: &mut TestRunner,
|
strategy: S,
|
||||||
num_samples: usize)
|
runner: &mut TestRunner,
|
||||||
-> HashSet<DMatrix<i32>>
|
num_samples: usize,
|
||||||
|
) -> HashSet<DMatrix<i32>>
|
||||||
where
|
where
|
||||||
S: Strategy,
|
S: Strategy,
|
||||||
DMatrix<i32>: for<'b> From<&'b S::Value>
|
DMatrix<i32>: for<'b> From<&'b S::Value>,
|
||||||
{
|
{
|
||||||
sample_strategy(strategy, runner)
|
sample_strategy(strategy, runner)
|
||||||
.take(num_samples)
|
.take(num_samples)
|
||||||
|
@ -220,8 +232,10 @@ mod slow {
|
||||||
.collect()
|
.collect()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn sample_strategy<'a, S: 'a + Strategy>(strategy: S, runner: &'a mut TestRunner)
|
fn sample_strategy<'a, S: 'a + Strategy>(
|
||||||
-> impl 'a + Iterator<Item=S::Value> {
|
strategy: S,
|
||||||
|
runner: &'a mut TestRunner,
|
||||||
|
) -> impl 'a + Iterator<Item = S::Value> {
|
||||||
repeat(()).map(move |_| {
|
repeat(()).map(move |_| {
|
||||||
let tree = strategy
|
let tree = strategy
|
||||||
.new_tree(runner)
|
.new_tree(runner)
|
||||||
|
@ -230,4 +244,4 @@ mod slow {
|
||||||
value
|
value
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue