This commit is contained in:
Andreas Longva 2021-01-25 17:26:27 +01:00
parent 795d818ae5
commit 7473d54d74
30 changed files with 1477 additions and 1069 deletions

View File

@ -1,17 +1,17 @@
use crate::coo::CooMatrix;
use crate::convert::serial::*; use crate::convert::serial::*;
use nalgebra::{Matrix, Scalar, Dim, ClosedAdd, DMatrix}; use crate::coo::CooMatrix;
use nalgebra::storage::{Storage};
use num_traits::Zero;
use crate::csr::CsrMatrix;
use crate::csc::CscMatrix; use crate::csc::CscMatrix;
use crate::csr::CsrMatrix;
use nalgebra::storage::Storage;
use nalgebra::{ClosedAdd, DMatrix, Dim, Matrix, Scalar};
use num_traits::Zero;
impl<'a, T, R, C, S> From<&'a Matrix<T, R, C, S>> for CooMatrix<T> impl<'a, T, R, C, S> From<&'a Matrix<T, R, C, S>> for CooMatrix<T>
where where
T: Scalar + Zero, T: Scalar + Zero,
R: Dim, R: Dim,
C: Dim, C: Dim,
S: Storage<T, R, C> S: Storage<T, R, C>,
{ {
fn from(matrix: &'a Matrix<T, R, C, S>) -> Self { fn from(matrix: &'a Matrix<T, R, C, S>) -> Self {
convert_dense_coo(matrix) convert_dense_coo(matrix)
@ -29,7 +29,7 @@ where
impl<'a, T> From<&'a CooMatrix<T>> for CsrMatrix<T> impl<'a, T> From<&'a CooMatrix<T>> for CsrMatrix<T>
where where
T: Scalar + Zero + ClosedAdd T: Scalar + Zero + ClosedAdd,
{ {
fn from(matrix: &'a CooMatrix<T>) -> Self { fn from(matrix: &'a CooMatrix<T>) -> Self {
convert_coo_csr(matrix) convert_coo_csr(matrix)
@ -38,7 +38,7 @@ where
impl<'a, T> From<&'a CsrMatrix<T>> for CooMatrix<T> impl<'a, T> From<&'a CsrMatrix<T>> for CooMatrix<T>
where where
T: Scalar + Zero + ClosedAdd T: Scalar + Zero + ClosedAdd,
{ {
fn from(matrix: &'a CsrMatrix<T>) -> Self { fn from(matrix: &'a CsrMatrix<T>) -> Self {
convert_csr_coo(matrix) convert_csr_coo(matrix)
@ -50,7 +50,7 @@ where
T: Scalar + Zero, T: Scalar + Zero,
R: Dim, R: Dim,
C: Dim, C: Dim,
S: Storage<T, R, C> S: Storage<T, R, C>,
{ {
fn from(matrix: &'a Matrix<T, R, C, S>) -> Self { fn from(matrix: &'a Matrix<T, R, C, S>) -> Self {
convert_dense_csr(matrix) convert_dense_csr(matrix)
@ -59,7 +59,7 @@ where
impl<'a, T> From<&'a CsrMatrix<T>> for DMatrix<T> impl<'a, T> From<&'a CsrMatrix<T>> for DMatrix<T>
where where
T: Scalar + Zero + ClosedAdd T: Scalar + Zero + ClosedAdd,
{ {
fn from(matrix: &'a CsrMatrix<T>) -> Self { fn from(matrix: &'a CsrMatrix<T>) -> Self {
convert_csr_dense(matrix) convert_csr_dense(matrix)
@ -68,7 +68,7 @@ where
impl<'a, T> From<&'a CooMatrix<T>> for CscMatrix<T> impl<'a, T> From<&'a CooMatrix<T>> for CscMatrix<T>
where where
T: Scalar + Zero + ClosedAdd T: Scalar + Zero + ClosedAdd,
{ {
fn from(matrix: &'a CooMatrix<T>) -> Self { fn from(matrix: &'a CooMatrix<T>) -> Self {
convert_coo_csc(matrix) convert_coo_csc(matrix)
@ -77,7 +77,7 @@ where
impl<'a, T> From<&'a CscMatrix<T>> for CooMatrix<T> impl<'a, T> From<&'a CscMatrix<T>> for CooMatrix<T>
where where
T: Scalar + Zero T: Scalar + Zero,
{ {
fn from(matrix: &'a CscMatrix<T>) -> Self { fn from(matrix: &'a CscMatrix<T>) -> Self {
convert_csc_coo(matrix) convert_csc_coo(matrix)
@ -89,7 +89,7 @@ impl<'a, T, R, C, S> From<&'a Matrix<T, R, C, S>> for CscMatrix<T>
T: Scalar + Zero, T: Scalar + Zero,
R: Dim, R: Dim,
C: Dim, C: Dim,
S: Storage<T, R, C> S: Storage<T, R, C>,
{ {
fn from(matrix: &'a Matrix<T, R, C, S>) -> Self { fn from(matrix: &'a Matrix<T, R, C, S>) -> Self {
convert_dense_csc(matrix) convert_dense_csc(matrix)
@ -98,7 +98,7 @@ impl<'a, T, R, C, S> From<&'a Matrix<T, R, C, S>> for CscMatrix<T>
impl<'a, T> From<&'a CscMatrix<T>> for DMatrix<T> impl<'a, T> From<&'a CscMatrix<T>> for DMatrix<T>
where where
T: Scalar + Zero + ClosedAdd T: Scalar + Zero + ClosedAdd,
{ {
fn from(matrix: &'a CscMatrix<T>) -> Self { fn from(matrix: &'a CscMatrix<T>) -> Self {
convert_csc_dense(matrix) convert_csc_dense(matrix)
@ -107,7 +107,7 @@ impl<'a, T> From<&'a CscMatrix<T>> for DMatrix<T>
impl<'a, T> From<&'a CscMatrix<T>> for CsrMatrix<T> impl<'a, T> From<&'a CscMatrix<T>> for CsrMatrix<T>
where where
T: Scalar T: Scalar,
{ {
fn from(matrix: &'a CscMatrix<T>) -> Self { fn from(matrix: &'a CscMatrix<T>) -> Self {
convert_csc_csr(matrix) convert_csc_csr(matrix)
@ -116,7 +116,7 @@ impl<'a, T> From<&'a CscMatrix<T>> for CsrMatrix<T>
impl<'a, T> From<&'a CsrMatrix<T>> for CscMatrix<T> impl<'a, T> From<&'a CsrMatrix<T>> for CscMatrix<T>
where where
T: Scalar T: Scalar,
{ {
fn from(matrix: &'a CsrMatrix<T>) -> Self { fn from(matrix: &'a CsrMatrix<T>) -> Self {
convert_csr_csc(matrix) convert_csr_csc(matrix)

View File

@ -7,8 +7,8 @@ use std::ops::Add;
use num_traits::Zero; use num_traits::Zero;
use nalgebra::{ClosedAdd, Dim, DMatrix, Matrix, Scalar};
use nalgebra::storage::Storage; use nalgebra::storage::Storage;
use nalgebra::{ClosedAdd, DMatrix, Dim, Matrix, Scalar};
use crate::coo::CooMatrix; use crate::coo::CooMatrix;
use crate::cs; use crate::cs;
@ -21,7 +21,7 @@ where
T: Scalar + Zero, T: Scalar + Zero,
R: Dim, R: Dim,
C: Dim, C: Dim,
S: Storage<T, R, C> S: Storage<T, R, C>,
{ {
let mut coo = CooMatrix::new(dense.nrows(), dense.ncols()); let mut coo = CooMatrix::new(dense.nrows(), dense.ncols());
@ -52,12 +52,14 @@ where
/// Converts a [`CooMatrix`] to a [`CsrMatrix`]. /// Converts a [`CooMatrix`] to a [`CsrMatrix`].
pub fn convert_coo_csr<T>(coo: &CooMatrix<T>) -> CsrMatrix<T> pub fn convert_coo_csr<T>(coo: &CooMatrix<T>) -> CsrMatrix<T>
where where
T: Scalar + Zero T: Scalar + Zero,
{ {
let (offsets, indices, values) = convert_coo_cs(coo.nrows(), let (offsets, indices, values) = convert_coo_cs(
coo.nrows(),
coo.row_indices(), coo.row_indices(),
coo.col_indices(), coo.col_indices(),
coo.values()); coo.values(),
);
// TODO: Avoid "try_from" since it validates the data? (requires unsafe, should benchmark // TODO: Avoid "try_from" since it validates the data? (requires unsafe, should benchmark
// to see if it can be justified for performance reasons) // to see if it can be justified for performance reasons)
@ -66,8 +68,7 @@ where
} }
/// Converts a [`CsrMatrix`] to a [`CooMatrix`]. /// Converts a [`CsrMatrix`] to a [`CooMatrix`].
pub fn convert_csr_coo<T: Scalar>(csr: &CsrMatrix<T>) -> CooMatrix<T> pub fn convert_csr_coo<T: Scalar>(csr: &CsrMatrix<T>) -> CooMatrix<T> {
{
let mut result = CooMatrix::new(csr.nrows(), csr.ncols()); let mut result = CooMatrix::new(csr.nrows(), csr.ncols());
for (i, j, v) in csr.triplet_iter() { for (i, j, v) in csr.triplet_iter() {
result.push(i, j, v.inlined_clone()); result.push(i, j, v.inlined_clone());
@ -78,7 +79,7 @@ pub fn convert_csr_coo<T: Scalar>(csr: &CsrMatrix<T>) -> CooMatrix<T>
/// Converts a [`CsrMatrix`] to a dense matrix. /// Converts a [`CsrMatrix`] to a dense matrix.
pub fn convert_csr_dense<T>(csr: &CsrMatrix<T>) -> DMatrix<T> pub fn convert_csr_dense<T>(csr: &CsrMatrix<T>) -> DMatrix<T>
where where
T: Scalar + ClosedAdd + Zero T: Scalar + ClosedAdd + Zero,
{ {
let mut output = DMatrix::zeros(csr.nrows(), csr.ncols()); let mut output = DMatrix::zeros(csr.nrows(), csr.ncols());
@ -95,7 +96,7 @@ where
T: Scalar + Zero, T: Scalar + Zero,
R: Dim, R: Dim,
C: Dim, C: Dim,
S: Storage<T, R, C> S: Storage<T, R, C>,
{ {
let mut row_offsets = Vec::with_capacity(dense.nrows() + 1); let mut row_offsets = Vec::with_capacity(dense.nrows() + 1);
let mut col_idx = Vec::new(); let mut col_idx = Vec::new();
@ -125,12 +126,14 @@ where
/// Converts a [`CooMatrix`] to a [`CscMatrix`]. /// Converts a [`CooMatrix`] to a [`CscMatrix`].
pub fn convert_coo_csc<T>(coo: &CooMatrix<T>) -> CscMatrix<T> pub fn convert_coo_csc<T>(coo: &CooMatrix<T>) -> CscMatrix<T>
where where
T: Scalar + Zero T: Scalar + Zero,
{ {
let (offsets, indices, values) = convert_coo_cs(coo.ncols(), let (offsets, indices, values) = convert_coo_cs(
coo.ncols(),
coo.col_indices(), coo.col_indices(),
coo.row_indices(), coo.row_indices(),
coo.values()); coo.values(),
);
// TODO: Avoid "try_from" since it validates the data? (requires unsafe, should benchmark // TODO: Avoid "try_from" since it validates the data? (requires unsafe, should benchmark
// to see if it can be justified for performance reasons) // to see if it can be justified for performance reasons)
@ -141,7 +144,7 @@ where
/// Converts a [`CscMatrix`] to a [`CooMatrix`]. /// Converts a [`CscMatrix`] to a [`CooMatrix`].
pub fn convert_csc_coo<T>(csc: &CscMatrix<T>) -> CooMatrix<T> pub fn convert_csc_coo<T>(csc: &CscMatrix<T>) -> CooMatrix<T>
where where
T: Scalar T: Scalar,
{ {
let mut coo = CooMatrix::new(csc.nrows(), csc.ncols()); let mut coo = CooMatrix::new(csc.nrows(), csc.ncols());
for (i, j, v) in csc.triplet_iter() { for (i, j, v) in csc.triplet_iter() {
@ -153,7 +156,7 @@ where
/// Converts a [`CscMatrix`] to a dense matrix. /// Converts a [`CscMatrix`] to a dense matrix.
pub fn convert_csc_dense<T>(csc: &CscMatrix<T>) -> DMatrix<T> pub fn convert_csc_dense<T>(csc: &CscMatrix<T>) -> DMatrix<T>
where where
T: Scalar + ClosedAdd + Zero T: Scalar + ClosedAdd + Zero,
{ {
let mut output = DMatrix::zeros(csc.nrows(), csc.ncols()); let mut output = DMatrix::zeros(csc.nrows(), csc.ncols());
@ -170,7 +173,7 @@ pub fn convert_dense_csc<T, R, C, S>(dense: &Matrix<T, R, C, S>) -> CscMatrix<T>
T: Scalar + Zero, T: Scalar + Zero,
R: Dim, R: Dim,
C: Dim, C: Dim,
S: Storage<T, R, C> S: Storage<T, R, C>,
{ {
let mut col_offsets = Vec::with_capacity(dense.ncols() + 1); let mut col_offsets = Vec::with_capacity(dense.ncols() + 1);
let mut row_idx = Vec::new(); let mut row_idx = Vec::new();
@ -197,13 +200,15 @@ pub fn convert_dense_csc<T, R, C, S>(dense: &Matrix<T, R, C, S>) -> CscMatrix<T>
/// Converts a [`CsrMatrix`] to a [`CscMatrix`]. /// Converts a [`CsrMatrix`] to a [`CscMatrix`].
pub fn convert_csr_csc<T>(csr: &CsrMatrix<T>) -> CscMatrix<T> pub fn convert_csr_csc<T>(csr: &CsrMatrix<T>) -> CscMatrix<T>
where where
T: Scalar T: Scalar,
{ {
let (offsets, indices, values) = cs::transpose_cs(csr.nrows(), let (offsets, indices, values) = cs::transpose_cs(
csr.nrows(),
csr.ncols(), csr.ncols(),
csr.row_offsets(), csr.row_offsets(),
csr.col_indices(), csr.col_indices(),
csr.values()); csr.values(),
);
// TODO: Avoid data validity check? // TODO: Avoid data validity check?
CscMatrix::try_from_csc_data(csr.nrows(), csr.ncols(), offsets, indices, values) CscMatrix::try_from_csc_data(csr.nrows(), csr.ncols(), offsets, indices, values)
@ -213,26 +218,29 @@ where
/// Converts a [`CscMatrix`] to a [`CsrMatrix`]. /// Converts a [`CscMatrix`] to a [`CsrMatrix`].
pub fn convert_csc_csr<T>(csc: &CscMatrix<T>) -> CsrMatrix<T> pub fn convert_csc_csr<T>(csc: &CscMatrix<T>) -> CsrMatrix<T>
where where
T: Scalar T: Scalar,
{ {
let (offsets, indices, values) = cs::transpose_cs(csc.ncols(), let (offsets, indices, values) = cs::transpose_cs(
csc.ncols(),
csc.nrows(), csc.nrows(),
csc.col_offsets(), csc.col_offsets(),
csc.row_indices(), csc.row_indices(),
csc.values()); csc.values(),
);
// TODO: Avoid data validity check? // TODO: Avoid data validity check?
CsrMatrix::try_from_csr_data(csc.nrows(), csc.ncols(), offsets, indices, values) CsrMatrix::try_from_csr_data(csc.nrows(), csc.ncols(), offsets, indices, values)
.expect("Internal error: Invalid CSR data during CSC->CSR conversion") .expect("Internal error: Invalid CSR data during CSC->CSR conversion")
} }
fn convert_coo_cs<T>(major_dim: usize, fn convert_coo_cs<T>(
major_dim: usize,
major_indices: &[usize], major_indices: &[usize],
minor_indices: &[usize], minor_indices: &[usize],
values: &[T]) values: &[T],
-> (Vec<usize>, Vec<usize>, Vec<T>) ) -> (Vec<usize>, Vec<usize>, Vec<T>)
where where
T: Scalar + Zero T: Scalar + Zero,
{ {
assert_eq!(major_indices.len(), minor_indices.len()); assert_eq!(major_indices.len(), minor_indices.len());
assert_eq!(minor_indices.len(), values.len()); assert_eq!(minor_indices.len(), values.len());

View File

@ -45,8 +45,7 @@ pub struct CooMatrix<T> {
values: Vec<T>, values: Vec<T>,
} }
impl<T> CooMatrix<T> impl<T> CooMatrix<T> {
{
/// Construct a zero COO matrix of the given dimensions. /// Construct a zero COO matrix of the given dimensions.
/// ///
/// Specifically, the collection of triplets - corresponding to explicitly stored entries - /// Specifically, the collection of triplets - corresponding to explicitly stored entries -
@ -78,11 +77,13 @@ impl<T> CooMatrix<T>
use crate::SparseFormatErrorKind::*; use crate::SparseFormatErrorKind::*;
if row_indices.len() != col_indices.len() { if row_indices.len() != col_indices.len() {
return Err(SparseFormatError::from_kind_and_msg( return Err(SparseFormatError::from_kind_and_msg(
InvalidStructure, "Number of row and col indices must be the same." InvalidStructure,
"Number of row and col indices must be the same.",
)); ));
} else if col_indices.len() != values.len() { } else if col_indices.len() != values.len() {
return Err(SparseFormatError::from_kind_and_msg( return Err(SparseFormatError::from_kind_and_msg(
InvalidStructure, "Number of col indices and values must be the same." InvalidStructure,
"Number of col indices and values must be the same.",
)); ));
} }
@ -90,9 +91,15 @@ impl<T> CooMatrix<T>
let col_indices_in_bounds = col_indices.iter().all(|j| *j < ncols); let col_indices_in_bounds = col_indices.iter().all(|j| *j < ncols);
if !row_indices_in_bounds { if !row_indices_in_bounds {
Err(SparseFormatError::from_kind_and_msg(IndexOutOfBounds, "Row index out of bounds.")) Err(SparseFormatError::from_kind_and_msg(
IndexOutOfBounds,
"Row index out of bounds.",
))
} else if !col_indices_in_bounds { } else if !col_indices_in_bounds {
Err(SparseFormatError::from_kind_and_msg(IndexOutOfBounds, "Col index out of bounds.")) Err(SparseFormatError::from_kind_and_msg(
IndexOutOfBounds,
"Col index out of bounds.",
))
} else { } else {
Ok(Self { Ok(Self {
nrows, nrows,

View File

@ -5,8 +5,8 @@ use num_traits::One;
use nalgebra::Scalar; use nalgebra::Scalar;
use crate::{SparseEntry, SparseEntryMut};
use crate::pattern::SparsityPattern; use crate::pattern::SparsityPattern;
use crate::{SparseEntry, SparseEntryMut};
/// An abstract compressed matrix. /// An abstract compressed matrix.
/// ///
@ -18,7 +18,7 @@ use crate::pattern::SparsityPattern;
#[derive(Debug, Clone, PartialEq, Eq)] #[derive(Debug, Clone, PartialEq, Eq)]
pub struct CsMatrix<T> { pub struct CsMatrix<T> {
sparsity_pattern: SparsityPattern, sparsity_pattern: SparsityPattern,
values: Vec<T> values: Vec<T>,
} }
impl<T> CsMatrix<T> { impl<T> CsMatrix<T> {
@ -50,14 +50,22 @@ impl<T> CsMatrix<T> {
#[inline] #[inline]
pub fn cs_data(&self) -> (&[usize], &[usize], &[T]) { pub fn cs_data(&self) -> (&[usize], &[usize], &[T]) {
let pattern = self.pattern(); let pattern = self.pattern();
(pattern.major_offsets(), pattern.minor_indices(), &self.values) (
pattern.major_offsets(),
pattern.minor_indices(),
&self.values,
)
} }
/// Returns the raw data represented as a tuple `(major_offsets, minor_indices, values)`. /// Returns the raw data represented as a tuple `(major_offsets, minor_indices, values)`.
#[inline] #[inline]
pub fn cs_data_mut(&mut self) -> (&[usize], &[usize], &mut [T]) { pub fn cs_data_mut(&mut self) -> (&[usize], &[usize], &mut [T]) {
let pattern = &mut self.sparsity_pattern; let pattern = &mut self.sparsity_pattern;
(pattern.major_offsets(), pattern.minor_indices(), &mut self.values) (
pattern.major_offsets(),
pattern.minor_indices(),
&mut self.values,
)
} }
#[inline] #[inline]
@ -66,9 +74,12 @@ impl<T> CsMatrix<T> {
} }
#[inline] #[inline]
pub fn from_pattern_and_values(pattern: SparsityPattern, values: Vec<T>) pub fn from_pattern_and_values(pattern: SparsityPattern, values: Vec<T>) -> Self {
-> Self { assert_eq!(
assert_eq!(pattern.nnz(), values.len(), "Internal error: consumers should verify shape compatibility."); pattern.nnz(),
values.len(),
"Internal error: consumers should verify shape compatibility."
);
Self { Self {
sparsity_pattern: pattern, sparsity_pattern: pattern,
values, values,
@ -105,13 +116,21 @@ impl<T> CsMatrix<T> {
let (_, minor_indices, values) = self.cs_data(); let (_, minor_indices, values) = self.cs_data();
let minor_indices = &minor_indices[row_range.clone()]; let minor_indices = &minor_indices[row_range.clone()];
let values = &values[row_range]; let values = &values[row_range];
get_entry_from_slices(self.pattern().minor_dim(), minor_indices, values, minor_index) get_entry_from_slices(
self.pattern().minor_dim(),
minor_indices,
values,
minor_index,
)
} }
/// Returns a mutable entry for the given major/minor indices, or `None` if the indices are out /// Returns a mutable entry for the given major/minor indices, or `None` if the indices are out
/// of bounds. /// of bounds.
pub fn get_entry_mut(&mut self, major_index: usize, minor_index: usize) pub fn get_entry_mut(
-> Option<SparseEntryMut<T>> { &mut self,
major_index: usize,
minor_index: usize,
) -> Option<SparseEntryMut<T>> {
let row_range = self.get_index_range(major_index)?; let row_range = self.get_index_range(major_index)?;
let minor_dim = self.pattern().minor_dim(); let minor_dim = self.pattern().minor_dim();
let (_, minor_indices, values) = self.cs_data_mut(); let (_, minor_indices, values) = self.cs_data_mut();
@ -126,7 +145,7 @@ impl<T> CsMatrix<T> {
Some(CsLane { Some(CsLane {
minor_indices: &minor_indices[range.clone()], minor_indices: &minor_indices[range.clone()],
values: &values[range], values: &values[range],
minor_dim: self.pattern().minor_dim() minor_dim: self.pattern().minor_dim(),
}) })
} }
@ -138,7 +157,7 @@ impl<T> CsMatrix<T> {
Some(CsLaneMut { Some(CsLaneMut {
minor_dim, minor_dim,
minor_indices: &minor_indices[range.clone()], minor_indices: &minor_indices[range.clone()],
values: &mut values[range] values: &mut values[range],
}) })
} }
@ -156,7 +175,7 @@ impl<T> CsMatrix<T> {
pub fn filter<P>(&self, predicate: P) -> Self pub fn filter<P>(&self, predicate: P) -> Self
where where
T: Clone, T: Clone,
P: Fn(usize, usize, &T) -> bool P: Fn(usize, usize, &T) -> bool,
{ {
let (major_dim, minor_dim) = (self.pattern().major_dim(), self.pattern().minor_dim()); let (major_dim, minor_dim) = (self.pattern().major_dim(), self.pattern().minor_dim());
let mut new_offsets = Vec::with_capacity(self.pattern().major_dim() + 1); let mut new_offsets = Vec::with_capacity(self.pattern().major_dim() + 1);
@ -180,7 +199,8 @@ impl<T> CsMatrix<T> {
major_dim, major_dim,
minor_dim, minor_dim,
new_offsets, new_offsets,
new_indices) new_indices,
)
.expect("Internal error: Sparsity pattern must always be valid."); .expect("Internal error: Sparsity pattern must always be valid.");
Self::from_pattern_and_values(new_pattern, new_values) Self::from_pattern_and_values(new_pattern, new_values)
@ -189,7 +209,7 @@ impl<T> CsMatrix<T> {
/// Returns the diagonal of the matrix as a sparse matrix. /// Returns the diagonal of the matrix as a sparse matrix.
pub fn diagonal_as_matrix(&self) -> Self pub fn diagonal_as_matrix(&self) -> Self
where where
T: Clone T: Clone,
{ {
// TODO: This might be faster with a binary search for each diagonal entry // TODO: This might be faster with a binary search for each diagonal entry
self.filter(|i, j, _| i == j) self.filter(|i, j, _| i == j)
@ -204,8 +224,8 @@ impl<T: Scalar + One> CsMatrix<T> {
let values = vec![T::one(); n]; let values = vec![T::one(); n];
// TODO: We should skip checks here // TODO: We should skip checks here
let pattern = SparsityPattern::try_from_offsets_and_indices(n, n, offsets, indices) let pattern =
.unwrap(); SparsityPattern::try_from_offsets_and_indices(n, n, offsets, indices).unwrap();
Self::from_pattern_and_values(pattern, values) Self::from_pattern_and_values(pattern, values)
} }
} }
@ -214,7 +234,8 @@ fn get_entry_from_slices<'a, T>(
minor_dim: usize, minor_dim: usize,
minor_indices: &'a [usize], minor_indices: &'a [usize],
values: &'a [T], values: &'a [T],
global_minor_index: usize) -> Option<SparseEntry<'a, T>> { global_minor_index: usize,
) -> Option<SparseEntry<'a, T>> {
let local_index = minor_indices.binary_search(&global_minor_index); let local_index = minor_indices.binary_search(&global_minor_index);
if let Ok(local_index) = local_index { if let Ok(local_index) = local_index {
Some(SparseEntry::NonZero(&values[local_index])) Some(SparseEntry::NonZero(&values[local_index]))
@ -229,7 +250,8 @@ fn get_mut_entry_from_slices<'a, T>(
minor_dim: usize, minor_dim: usize,
minor_indices: &'a [usize], minor_indices: &'a [usize],
values: &'a mut [T], values: &'a mut [T],
global_minor_indices: usize) -> Option<SparseEntryMut<'a, T>> { global_minor_indices: usize,
) -> Option<SparseEntryMut<'a, T>> {
let local_index = minor_indices.binary_search(&global_minor_indices); let local_index = minor_indices.binary_search(&global_minor_indices);
if let Ok(local_index) = local_index { if let Ok(local_index) = local_index {
Some(SparseEntryMut::NonZero(&mut values[local_index])) Some(SparseEntryMut::NonZero(&mut values[local_index]))
@ -244,14 +266,14 @@ fn get_mut_entry_from_slices<'a, T>(
pub struct CsLane<'a, T> { pub struct CsLane<'a, T> {
minor_dim: usize, minor_dim: usize,
minor_indices: &'a [usize], minor_indices: &'a [usize],
values: &'a [T] values: &'a [T],
} }
#[derive(Debug, PartialEq, Eq)] #[derive(Debug, PartialEq, Eq)]
pub struct CsLaneMut<'a, T> { pub struct CsLaneMut<'a, T> {
minor_dim: usize, minor_dim: usize,
minor_indices: &'a [usize], minor_indices: &'a [usize],
values: &'a mut [T] values: &'a mut [T],
} }
pub struct CsLaneIter<'a, T> { pub struct CsLaneIter<'a, T> {
@ -266,14 +288,14 @@ impl<'a, T> CsLaneIter<'a, T> {
Self { Self {
current_lane_idx: 0, current_lane_idx: 0,
pattern, pattern,
remaining_values: values remaining_values: values,
} }
} }
} }
impl<'a, T> Iterator for CsLaneIter<'a, T> impl<'a, T> Iterator for CsLaneIter<'a, T>
where where
T: 'a T: 'a,
{ {
type Item = CsLane<'a, T>; type Item = CsLane<'a, T>;
@ -290,7 +312,7 @@ impl<'a, T> Iterator for CsLaneIter<'a, T>
Some(CsLane { Some(CsLane {
minor_dim, minor_dim,
minor_indices, minor_indices,
values: values_in_lane values: values_in_lane,
}) })
} else { } else {
None None
@ -310,14 +332,14 @@ impl<'a, T> CsLaneIterMut<'a, T> {
Self { Self {
current_lane_idx: 0, current_lane_idx: 0,
pattern, pattern,
remaining_values: values remaining_values: values,
} }
} }
} }
impl<'a, T> Iterator for CsLaneIterMut<'a, T> impl<'a, T> Iterator for CsLaneIterMut<'a, T>
where where
T: 'a T: 'a,
{ {
type Item = CsLaneMut<'a, T>; type Item = CsLaneMut<'a, T>;
@ -336,7 +358,7 @@ impl<'a, T> Iterator for CsLaneIterMut<'a, T>
Some(CsLaneMut { Some(CsLaneMut {
minor_dim, minor_dim,
minor_indices, minor_indices,
values: values_in_lane values: values_in_lane,
}) })
} else { } else {
None None
@ -375,10 +397,11 @@ macro_rules! impl_cs_lane_common_methods {
self.minor_dim, self.minor_dim,
self.minor_indices, self.minor_indices,
self.values, self.values,
global_col_index) global_col_index,
} )
} }
} }
};
} }
impl_cs_lane_common_methods!(CsLane<'a, T>); impl_cs_lane_common_methods!(CsLane<'a, T>);
@ -394,10 +417,12 @@ impl<'a, T> CsLaneMut<'a, T> {
} }
pub fn get_entry_mut(&mut self, global_minor_index: usize) -> Option<SparseEntryMut<T>> { pub fn get_entry_mut(&mut self, global_minor_index: usize) -> Option<SparseEntryMut<T>> {
get_mut_entry_from_slices(self.minor_dim, get_mut_entry_from_slices(
self.minor_dim,
self.minor_indices, self.minor_indices,
self.values, self.values,
global_minor_index) global_minor_index,
)
} }
} }
@ -405,7 +430,7 @@ impl<'a, T> CsLaneMut<'a, T> {
/// TODO: This doesn't belong here. /// TODO: This doesn't belong here.
struct UninitVec<T> { struct UninitVec<T> {
vec: Vec<T>, vec: Vec<T>,
len: usize len: usize,
} }
impl<T> UninitVec<T> { impl<T> UninitVec<T> {
@ -414,7 +439,7 @@ impl<T> UninitVec<T> {
vec: Vec::with_capacity(len), vec: Vec::with_capacity(len),
// We need to store len separately, because for zero-sized types, // We need to store len separately, because for zero-sized types,
// Vec::with_capacity(len) does not give vec.capacity() == len // Vec::with_capacity(len) does not give vec.capacity() == len
len len,
} }
} }
@ -444,10 +469,10 @@ pub fn transpose_cs<T>(
minor_dim: usize, minor_dim: usize,
source_major_offsets: &[usize], source_major_offsets: &[usize],
source_minor_indices: &[usize], source_minor_indices: &[usize],
values: &[T]) values: &[T],
-> (Vec<usize>, Vec<usize>, Vec<T>) ) -> (Vec<usize>, Vec<usize>, Vec<T>)
where where
T: Scalar T: Scalar,
{ {
assert_eq!(source_major_offsets.len(), major_dim + 1); assert_eq!(source_major_offsets.len(), major_dim + 1);
assert_eq!(source_minor_indices.len(), values.len()); assert_eq!(source_minor_indices.len(), values.len());
@ -481,7 +506,9 @@ where
let target_lane_count = &mut current_target_major_counts[source_minor_idx]; let target_lane_count = &mut current_target_major_counts[source_minor_idx];
let entry_offset = target_offsets[source_minor_idx] + *target_lane_count; let entry_offset = target_offsets[source_minor_idx] + *target_lane_count;
target_indices[entry_offset] = source_major_idx; target_indices[entry_offset] = source_major_idx;
unsafe { target_values.set(entry_offset, val.inlined_clone()); } unsafe {
target_values.set(entry_offset, val.inlined_clone());
}
*target_lane_count += 1; *target_lane_count += 1;
} }
} }

View File

@ -3,14 +3,14 @@
//! This is the module-level documentation. See [`CscMatrix`] for the main documentation of the //! This is the module-level documentation. See [`CscMatrix`] for the main documentation of the
//! CSC implementation. //! CSC implementation.
use crate::{SparseFormatError, SparseFormatErrorKind, SparseEntry, SparseEntryMut}; use crate::cs::{CsLane, CsLaneIter, CsLaneIterMut, CsLaneMut, CsMatrix};
use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter};
use crate::csr::CsrMatrix; use crate::csr::CsrMatrix;
use crate::cs::{CsMatrix, CsLane, CsLaneMut, CsLaneIter, CsLaneIterMut}; use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter};
use crate::{SparseEntry, SparseEntryMut, SparseFormatError, SparseFormatErrorKind};
use std::slice::{IterMut, Iter};
use num_traits::{One};
use nalgebra::Scalar; use nalgebra::Scalar;
use num_traits::One;
use std::slice::{Iter, IterMut};
/// A CSC representation of a sparse matrix. /// A CSC representation of a sparse matrix.
/// ///
@ -130,7 +130,7 @@ impl<T> CscMatrix<T> {
/// Create a zero CSC matrix with no explicitly stored entries. /// Create a zero CSC matrix with no explicitly stored entries.
pub fn zeros(nrows: usize, ncols: usize) -> Self { pub fn zeros(nrows: usize, ncols: usize) -> Self {
Self { Self {
cs: CsMatrix::new(ncols, nrows) cs: CsMatrix::new(ncols, nrows),
} }
} }
@ -196,7 +196,11 @@ impl<T> CscMatrix<T> {
values: Vec<T>, values: Vec<T>,
) -> Result<Self, SparseFormatError> { ) -> Result<Self, SparseFormatError> {
let pattern = SparsityPattern::try_from_offsets_and_indices( let pattern = SparsityPattern::try_from_offsets_and_indices(
num_cols, num_rows, col_offsets, row_indices) num_cols,
num_rows,
col_offsets,
row_indices,
)
.map_err(pattern_format_error_to_csc_error)?; .map_err(pattern_format_error_to_csc_error)?;
Self::try_from_pattern_and_values(pattern, values) Self::try_from_pattern_and_values(pattern, values)
} }
@ -205,16 +209,19 @@ impl<T> CscMatrix<T> {
/// ///
/// Returns an error if the number of values does not match the number of minor indices /// Returns an error if the number of values does not match the number of minor indices
/// in the pattern. /// in the pattern.
pub fn try_from_pattern_and_values(pattern: SparsityPattern, values: Vec<T>) pub fn try_from_pattern_and_values(
-> Result<Self, SparseFormatError> { pattern: SparsityPattern,
values: Vec<T>,
) -> Result<Self, SparseFormatError> {
if pattern.nnz() == values.len() { if pattern.nnz() == values.len() {
Ok(Self { Ok(Self {
cs: CsMatrix::from_pattern_and_values(pattern, values) cs: CsMatrix::from_pattern_and_values(pattern, values),
}) })
} else { } else {
Err(SparseFormatError::from_kind_and_msg( Err(SparseFormatError::from_kind_and_msg(
SparseFormatErrorKind::InvalidStructure, SparseFormatErrorKind::InvalidStructure,
"Number of values and row indices must be the same")) "Number of values and row indices must be the same",
))
} }
} }
@ -239,7 +246,7 @@ impl<T> CscMatrix<T> {
pub fn triplet_iter(&self) -> CscTripletIter<T> { pub fn triplet_iter(&self) -> CscTripletIter<T> {
CscTripletIter { CscTripletIter {
pattern_iter: self.pattern().entries(), pattern_iter: self.pattern().entries(),
values_iter: self.values().iter() values_iter: self.values().iter(),
} }
} }
@ -270,7 +277,7 @@ impl<T> CscMatrix<T> {
let (pattern, values) = self.cs.pattern_and_values_mut(); let (pattern, values) = self.cs.pattern_and_values_mut();
CscTripletIterMut { CscTripletIterMut {
pattern_iter: pattern.entries(), pattern_iter: pattern.entries(),
values_mut_iter: values.iter_mut() values_mut_iter: values.iter_mut(),
} }
} }
@ -281,8 +288,7 @@ impl<T> CscMatrix<T> {
/// Panics if column index is out of bounds. /// Panics if column index is out of bounds.
#[inline] #[inline]
pub fn col(&self, index: usize) -> CscCol<T> { pub fn col(&self, index: usize) -> CscCol<T> {
self.get_col(index) self.get_col(index).expect("Row index must be in bounds")
.expect("Row index must be in bounds")
} }
/// Mutable column access for the given column index. /// Mutable column access for the given column index.
@ -299,23 +305,19 @@ impl<T> CscMatrix<T> {
/// Return the column at the given column index, or `None` if out of bounds. /// Return the column at the given column index, or `None` if out of bounds.
#[inline] #[inline]
pub fn get_col(&self, index: usize) -> Option<CscCol<T>> { pub fn get_col(&self, index: usize) -> Option<CscCol<T>> {
self.cs self.cs.get_lane(index).map(|lane| CscCol { lane })
.get_lane(index)
.map(|lane| CscCol { lane })
} }
/// Mutable column access for the given column index, or `None` if out of bounds. /// Mutable column access for the given column index, or `None` if out of bounds.
#[inline] #[inline]
pub fn get_col_mut(&mut self, index: usize) -> Option<CscColMut<T>> { pub fn get_col_mut(&mut self, index: usize) -> Option<CscColMut<T>> {
self.cs self.cs.get_lane_mut(index).map(|lane| CscColMut { lane })
.get_lane_mut(index)
.map(|lane| CscColMut { lane })
} }
/// An iterator over columns in the matrix. /// An iterator over columns in the matrix.
pub fn col_iter(&self) -> CscColIter<T> { pub fn col_iter(&self) -> CscColIter<T> {
CscColIter { CscColIter {
lane_iter: CsLaneIter::new(self.pattern(), self.values()) lane_iter: CsLaneIter::new(self.pattern(), self.values()),
} }
} }
@ -323,7 +325,7 @@ impl<T> CscMatrix<T> {
pub fn col_iter_mut(&mut self) -> CscColIterMut<T> { pub fn col_iter_mut(&mut self) -> CscColIterMut<T> {
let (pattern, values) = self.cs.pattern_and_values_mut(); let (pattern, values) = self.cs.pattern_and_values_mut();
CscColIterMut { CscColIterMut {
lane_iter: CsLaneIterMut::new(pattern, values) lane_iter: CsLaneIterMut::new(pattern, values),
} }
} }
@ -397,8 +399,11 @@ impl<T> CscMatrix<T> {
/// ///
/// Each call to this function incurs the cost of a binary search among the explicitly /// Each call to this function incurs the cost of a binary search among the explicitly
/// stored row entries for the given column. /// stored row entries for the given column.
pub fn get_entry_mut(&mut self, row_index: usize, col_index: usize) pub fn get_entry_mut(
-> Option<SparseEntryMut<T>> { &mut self,
row_index: usize,
col_index: usize,
) -> Option<SparseEntryMut<T>> {
self.cs.get_entry_mut(col_index, row_index) self.cs.get_entry_mut(col_index, row_index)
} }
@ -444,11 +449,15 @@ impl<T> CscMatrix<T> {
pub fn filter<P>(&self, predicate: P) -> Self pub fn filter<P>(&self, predicate: P) -> Self
where where
T: Clone, T: Clone,
P: Fn(usize, usize, &T) -> bool P: Fn(usize, usize, &T) -> bool,
{ {
// Note: Predicate uses (row, col, value), so we have to switch around since // Note: Predicate uses (row, col, value), so we have to switch around since
// cs uses (major, minor, value) // cs uses (major, minor, value)
Self { cs: self.cs.filter(|col_idx, row_idx, v| predicate(row_idx, col_idx, v)) } Self {
cs: self
.cs
.filter(|col_idx, row_idx, v| predicate(row_idx, col_idx, v)),
}
} }
/// Returns a new matrix representing the upper triangular part of this matrix. /// Returns a new matrix representing the upper triangular part of this matrix.
@ -456,7 +465,7 @@ impl<T> CscMatrix<T> {
/// The result includes the diagonal of the matrix. /// The result includes the diagonal of the matrix.
pub fn upper_triangle(&self) -> Self pub fn upper_triangle(&self) -> Self
where where
T: Clone T: Clone,
{ {
self.filter(|i, j, _| i <= j) self.filter(|i, j, _| i <= j)
} }
@ -466,7 +475,7 @@ impl<T> CscMatrix<T> {
/// The result includes the diagonal of the matrix. /// The result includes the diagonal of the matrix.
pub fn lower_triangle(&self) -> Self pub fn lower_triangle(&self) -> Self
where where
T: Clone T: Clone,
{ {
self.filter(|i, j, _| i >= j) self.filter(|i, j, _| i >= j)
} }
@ -474,15 +483,17 @@ impl<T> CscMatrix<T> {
/// Returns the diagonal of the matrix as a sparse matrix. /// Returns the diagonal of the matrix as a sparse matrix.
pub fn diagonal_as_csc(&self) -> Self pub fn diagonal_as_csc(&self) -> Self
where where
T: Clone T: Clone,
{ {
Self { cs: self.cs.diagonal_as_matrix() } Self {
cs: self.cs.diagonal_as_matrix(),
}
} }
} }
impl<T> CscMatrix<T> impl<T> CscMatrix<T>
where where
T: Scalar T: Scalar,
{ {
/// Compute the transpose of the matrix. /// Compute the transpose of the matrix.
pub fn transpose(&self) -> CscMatrix<T> { pub fn transpose(&self) -> CscMatrix<T> {
@ -495,7 +506,7 @@ impl<T: Scalar + One> CscMatrix<T> {
#[inline] #[inline]
pub fn identity(n: usize) -> Self { pub fn identity(n: usize) -> Self {
Self { Self {
cs: CsMatrix::identity(n) cs: CsMatrix::identity(n),
} }
} }
} }
@ -505,30 +516,34 @@ impl<T: Scalar + One> CscMatrix<T> {
/// This ensures that the terminology is consistent: we are talking about rows and columns, /// This ensures that the terminology is consistent: we are talking about rows and columns,
/// not lanes, major and minor dimensions. /// not lanes, major and minor dimensions.
fn pattern_format_error_to_csc_error(err: SparsityPatternFormatError) -> SparseFormatError { fn pattern_format_error_to_csc_error(err: SparsityPatternFormatError) -> SparseFormatError {
use SparsityPatternFormatError::*;
use SparsityPatternFormatError::DuplicateEntry as PatternDuplicateEntry;
use SparseFormatError as E; use SparseFormatError as E;
use SparseFormatErrorKind as K; use SparseFormatErrorKind as K;
use SparsityPatternFormatError::DuplicateEntry as PatternDuplicateEntry;
use SparsityPatternFormatError::*;
match err { match err {
InvalidOffsetArrayLength => E::from_kind_and_msg( InvalidOffsetArrayLength => E::from_kind_and_msg(
K::InvalidStructure, K::InvalidStructure,
"Length of col offset array is not equal to ncols + 1."), "Length of col offset array is not equal to ncols + 1.",
),
InvalidOffsetFirstLast => E::from_kind_and_msg( InvalidOffsetFirstLast => E::from_kind_and_msg(
K::InvalidStructure, K::InvalidStructure,
"First or last col offset is inconsistent with format specification."), "First or last col offset is inconsistent with format specification.",
),
NonmonotonicOffsets => E::from_kind_and_msg( NonmonotonicOffsets => E::from_kind_and_msg(
K::InvalidStructure, K::InvalidStructure,
"Col offsets are not monotonically increasing."), "Col offsets are not monotonically increasing.",
),
NonmonotonicMinorIndices => E::from_kind_and_msg( NonmonotonicMinorIndices => E::from_kind_and_msg(
K::InvalidStructure, K::InvalidStructure,
"Row indices are not monotonically increasing (sorted) within each column."), "Row indices are not monotonically increasing (sorted) within each column.",
MinorIndexOutOfBounds => E::from_kind_and_msg( ),
K::IndexOutOfBounds, MinorIndexOutOfBounds => {
"Row indices are out of bounds."), E::from_kind_and_msg(K::IndexOutOfBounds, "Row indices are out of bounds.")
PatternDuplicateEntry => E::from_kind_and_msg( }
K::DuplicateEntry, PatternDuplicateEntry => {
"Matrix data contains duplicate entries."), E::from_kind_and_msg(K::DuplicateEntry, "Matrix data contains duplicate entries.")
}
} }
} }
@ -536,7 +551,7 @@ fn pattern_format_error_to_csc_error(err: SparsityPatternFormatError) -> SparseF
#[derive(Debug)] #[derive(Debug)]
pub struct CscTripletIter<'a, T> { pub struct CscTripletIter<'a, T> {
pattern_iter: SparsityPatternIter<'a>, pattern_iter: SparsityPatternIter<'a>,
values_iter: Iter<'a, T> values_iter: Iter<'a, T>,
} }
impl<'a, T: Clone> CscTripletIter<'a, T> { impl<'a, T: Clone> CscTripletIter<'a, T> {
@ -559,7 +574,7 @@ impl<'a, T> Iterator for CscTripletIter<'a, T> {
match (next_entry, next_value) { match (next_entry, next_value) {
(Some((i, j)), Some(v)) => Some((j, i, v)), (Some((i, j)), Some(v)) => Some((j, i, v)),
_ => None _ => None,
} }
} }
} }
@ -568,7 +583,7 @@ impl<'a, T> Iterator for CscTripletIter<'a, T> {
#[derive(Debug)] #[derive(Debug)]
pub struct CscTripletIterMut<'a, T> { pub struct CscTripletIterMut<'a, T> {
pattern_iter: SparsityPatternIter<'a>, pattern_iter: SparsityPatternIter<'a>,
values_mut_iter: IterMut<'a, T> values_mut_iter: IterMut<'a, T>,
} }
impl<'a, T> Iterator for CscTripletIterMut<'a, T> { impl<'a, T> Iterator for CscTripletIterMut<'a, T> {
@ -581,7 +596,7 @@ impl<'a, T> Iterator for CscTripletIterMut<'a, T> {
match (next_entry, next_value) { match (next_entry, next_value) {
(Some((i, j)), Some(v)) => Some((j, i, v)), (Some((i, j)), Some(v)) => Some((j, i, v)),
_ => None _ => None,
} }
} }
} }
@ -589,7 +604,7 @@ impl<'a, T> Iterator for CscTripletIterMut<'a, T> {
/// An immutable representation of a column in a CSC matrix. /// An immutable representation of a column in a CSC matrix.
#[derive(Debug, Clone, PartialEq, Eq)] #[derive(Debug, Clone, PartialEq, Eq)]
pub struct CscCol<'a, T> { pub struct CscCol<'a, T> {
lane: CsLane<'a, T> lane: CsLane<'a, T>,
} }
/// A mutable representation of a column in a CSC matrix. /// A mutable representation of a column in a CSC matrix.
@ -598,7 +613,7 @@ pub struct CscCol<'a, T> {
/// to the column cannot be modified. /// to the column cannot be modified.
#[derive(Debug, PartialEq, Eq)] #[derive(Debug, PartialEq, Eq)]
pub struct CscColMut<'a, T> { pub struct CscColMut<'a, T> {
lane: CsLaneMut<'a, T> lane: CsLaneMut<'a, T>,
} }
/// Implement the methods common to both CscCol and CscColMut /// Implement the methods common to both CscCol and CscColMut
@ -637,7 +652,7 @@ macro_rules! impl_csc_col_common_methods {
self.lane.get_entry(global_row_index) self.lane.get_entry(global_row_index)
} }
} }
} };
} }
impl_csc_col_common_methods!(CscCol<'a, T>); impl_csc_col_common_methods!(CscCol<'a, T>);
@ -666,33 +681,29 @@ impl<'a, T> CscColMut<'a, T> {
/// Column iterator for [CscMatrix](struct.CscMatrix.html). /// Column iterator for [CscMatrix](struct.CscMatrix.html).
pub struct CscColIter<'a, T> { pub struct CscColIter<'a, T> {
lane_iter: CsLaneIter<'a, T> lane_iter: CsLaneIter<'a, T>,
} }
impl<'a, T> Iterator for CscColIter<'a, T> { impl<'a, T> Iterator for CscColIter<'a, T> {
type Item = CscCol<'a, T>; type Item = CscCol<'a, T>;
fn next(&mut self) -> Option<Self::Item> { fn next(&mut self) -> Option<Self::Item> {
self.lane_iter self.lane_iter.next().map(|lane| CscCol { lane })
.next()
.map(|lane| CscCol { lane })
} }
} }
/// Mutable column iterator for [CscMatrix](struct.CscMatrix.html). /// Mutable column iterator for [CscMatrix](struct.CscMatrix.html).
pub struct CscColIterMut<'a, T> { pub struct CscColIterMut<'a, T> {
lane_iter: CsLaneIterMut<'a, T> lane_iter: CsLaneIterMut<'a, T>,
} }
impl<'a, T> Iterator for CscColIterMut<'a, T> impl<'a, T> Iterator for CscColIterMut<'a, T>
where where
T: 'a T: 'a,
{ {
type Item = CscColMut<'a, T>; type Item = CscColMut<'a, T>;
fn next(&mut self) -> Option<Self::Item> { fn next(&mut self) -> Option<Self::Item> {
self.lane_iter self.lane_iter.next().map(|lane| CscColMut { lane })
.next()
.map(|lane| CscColMut { lane })
} }
} }

View File

@ -2,15 +2,15 @@
//! //!
//! This is the module-level documentation. See [`CsrMatrix`] for the main documentation of the //! This is the module-level documentation. See [`CsrMatrix`] for the main documentation of the
//! CSC implementation. //! CSC implementation.
use crate::{SparseFormatError, SparseFormatErrorKind, SparseEntry, SparseEntryMut}; use crate::cs::{CsLane, CsLaneIter, CsLaneIterMut, CsLaneMut, CsMatrix};
use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter};
use crate::csc::CscMatrix; use crate::csc::CscMatrix;
use crate::cs::{CsMatrix, CsLaneIterMut, CsLaneIter, CsLane, CsLaneMut}; use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter};
use crate::{SparseEntry, SparseEntryMut, SparseFormatError, SparseFormatErrorKind};
use nalgebra::Scalar; use nalgebra::Scalar;
use num_traits::{One}; use num_traits::One;
use std::slice::{IterMut, Iter}; use std::slice::{Iter, IterMut};
/// A CSR representation of a sparse matrix. /// A CSR representation of a sparse matrix.
/// ///
@ -130,7 +130,7 @@ impl<T> CsrMatrix<T> {
/// Create a zero CSR matrix with no explicitly stored entries. /// Create a zero CSR matrix with no explicitly stored entries.
pub fn zeros(nrows: usize, ncols: usize) -> Self { pub fn zeros(nrows: usize, ncols: usize) -> Self {
Self { Self {
cs: CsMatrix::new(nrows, ncols) cs: CsMatrix::new(nrows, ncols),
} }
} }
@ -198,7 +198,11 @@ impl<T> CsrMatrix<T> {
values: Vec<T>, values: Vec<T>,
) -> Result<Self, SparseFormatError> { ) -> Result<Self, SparseFormatError> {
let pattern = SparsityPattern::try_from_offsets_and_indices( let pattern = SparsityPattern::try_from_offsets_and_indices(
num_rows, num_cols, row_offsets, col_indices) num_rows,
num_cols,
row_offsets,
col_indices,
)
.map_err(pattern_format_error_to_csr_error)?; .map_err(pattern_format_error_to_csr_error)?;
Self::try_from_pattern_and_values(pattern, values) Self::try_from_pattern_and_values(pattern, values)
} }
@ -207,16 +211,19 @@ impl<T> CsrMatrix<T> {
/// ///
/// Returns an error if the number of values does not match the number of minor indices /// Returns an error if the number of values does not match the number of minor indices
/// in the pattern. /// in the pattern.
pub fn try_from_pattern_and_values(pattern: SparsityPattern, values: Vec<T>) pub fn try_from_pattern_and_values(
-> Result<Self, SparseFormatError> { pattern: SparsityPattern,
values: Vec<T>,
) -> Result<Self, SparseFormatError> {
if pattern.nnz() == values.len() { if pattern.nnz() == values.len() {
Ok(Self { Ok(Self {
cs: CsMatrix::from_pattern_and_values(pattern, values) cs: CsMatrix::from_pattern_and_values(pattern, values),
}) })
} else { } else {
Err(SparseFormatError::from_kind_and_msg( Err(SparseFormatError::from_kind_and_msg(
SparseFormatErrorKind::InvalidStructure, SparseFormatErrorKind::InvalidStructure,
"Number of values and column indices must be the same")) "Number of values and column indices must be the same",
))
} }
} }
@ -241,7 +248,7 @@ impl<T> CsrMatrix<T> {
pub fn triplet_iter(&self) -> CsrTripletIter<T> { pub fn triplet_iter(&self) -> CsrTripletIter<T> {
CsrTripletIter { CsrTripletIter {
pattern_iter: self.pattern().entries(), pattern_iter: self.pattern().entries(),
values_iter: self.values().iter() values_iter: self.values().iter(),
} }
} }
@ -272,7 +279,7 @@ impl<T> CsrMatrix<T> {
let (pattern, values) = self.cs.pattern_and_values_mut(); let (pattern, values) = self.cs.pattern_and_values_mut();
CsrTripletIterMut { CsrTripletIterMut {
pattern_iter: pattern.entries(), pattern_iter: pattern.entries(),
values_mut_iter: values.iter_mut() values_mut_iter: values.iter_mut(),
} }
} }
@ -283,8 +290,7 @@ impl<T> CsrMatrix<T> {
/// Panics if row index is out of bounds. /// Panics if row index is out of bounds.
#[inline] #[inline]
pub fn row(&self, index: usize) -> CsrRow<T> { pub fn row(&self, index: usize) -> CsrRow<T> {
self.get_row(index) self.get_row(index).expect("Row index must be in bounds")
.expect("Row index must be in bounds")
} }
/// Mutable row access for the given row index. /// Mutable row access for the given row index.
@ -301,23 +307,19 @@ impl<T> CsrMatrix<T> {
/// Return the row at the given row index, or `None` if out of bounds. /// Return the row at the given row index, or `None` if out of bounds.
#[inline] #[inline]
pub fn get_row(&self, index: usize) -> Option<CsrRow<T>> { pub fn get_row(&self, index: usize) -> Option<CsrRow<T>> {
self.cs self.cs.get_lane(index).map(|lane| CsrRow { lane })
.get_lane(index)
.map(|lane| CsrRow { lane })
} }
/// Mutable row access for the given row index, or `None` if out of bounds. /// Mutable row access for the given row index, or `None` if out of bounds.
#[inline] #[inline]
pub fn get_row_mut(&mut self, index: usize) -> Option<CsrRowMut<T>> { pub fn get_row_mut(&mut self, index: usize) -> Option<CsrRowMut<T>> {
self.cs self.cs.get_lane_mut(index).map(|lane| CsrRowMut { lane })
.get_lane_mut(index)
.map(|lane| CsrRowMut { lane })
} }
/// An iterator over rows in the matrix. /// An iterator over rows in the matrix.
pub fn row_iter(&self) -> CsrRowIter<T> { pub fn row_iter(&self) -> CsrRowIter<T> {
CsrRowIter { CsrRowIter {
lane_iter: CsLaneIter::new(self.pattern(), self.values()) lane_iter: CsLaneIter::new(self.pattern(), self.values()),
} }
} }
@ -399,8 +401,11 @@ impl<T> CsrMatrix<T> {
/// ///
/// Each call to this function incurs the cost of a binary search among the explicitly /// Each call to this function incurs the cost of a binary search among the explicitly
/// stored column entries for the given row. /// stored column entries for the given row.
pub fn get_entry_mut(&mut self, row_index: usize, col_index: usize) pub fn get_entry_mut(
-> Option<SparseEntryMut<T>> { &mut self,
row_index: usize,
col_index: usize,
) -> Option<SparseEntryMut<T>> {
self.cs.get_entry_mut(row_index, col_index) self.cs.get_entry_mut(row_index, col_index)
} }
@ -446,9 +451,13 @@ impl<T> CsrMatrix<T> {
pub fn filter<P>(&self, predicate: P) -> Self pub fn filter<P>(&self, predicate: P) -> Self
where where
T: Clone, T: Clone,
P: Fn(usize, usize, &T) -> bool P: Fn(usize, usize, &T) -> bool,
{ {
Self { cs: self.cs.filter(|row_idx, col_idx, v| predicate(row_idx, col_idx, v)) } Self {
cs: self
.cs
.filter(|row_idx, col_idx, v| predicate(row_idx, col_idx, v)),
}
} }
/// Returns a new matrix representing the upper triangular part of this matrix. /// Returns a new matrix representing the upper triangular part of this matrix.
@ -456,7 +465,7 @@ impl<T> CsrMatrix<T> {
/// The result includes the diagonal of the matrix. /// The result includes the diagonal of the matrix.
pub fn upper_triangle(&self) -> Self pub fn upper_triangle(&self) -> Self
where where
T: Clone T: Clone,
{ {
self.filter(|i, j, _| i <= j) self.filter(|i, j, _| i <= j)
} }
@ -466,7 +475,7 @@ impl<T> CsrMatrix<T> {
/// The result includes the diagonal of the matrix. /// The result includes the diagonal of the matrix.
pub fn lower_triangle(&self) -> Self pub fn lower_triangle(&self) -> Self
where where
T: Clone T: Clone,
{ {
self.filter(|i, j, _| i >= j) self.filter(|i, j, _| i >= j)
} }
@ -474,15 +483,17 @@ impl<T> CsrMatrix<T> {
/// Returns the diagonal of the matrix as a sparse matrix. /// Returns the diagonal of the matrix as a sparse matrix.
pub fn diagonal_as_csr(&self) -> Self pub fn diagonal_as_csr(&self) -> Self
where where
T: Clone T: Clone,
{ {
Self { cs: self.cs.diagonal_as_matrix() } Self {
cs: self.cs.diagonal_as_matrix(),
}
} }
} }
impl<T> CsrMatrix<T> impl<T> CsrMatrix<T>
where where
T: Scalar T: Scalar,
{ {
/// Compute the transpose of the matrix. /// Compute the transpose of the matrix.
pub fn transpose(&self) -> CsrMatrix<T> { pub fn transpose(&self) -> CsrMatrix<T> {
@ -495,7 +506,7 @@ impl<T: Scalar + One> CsrMatrix<T> {
#[inline] #[inline]
pub fn identity(n: usize) -> Self { pub fn identity(n: usize) -> Self {
Self { Self {
cs: CsMatrix::identity(n) cs: CsMatrix::identity(n),
} }
} }
} }
@ -505,30 +516,34 @@ impl<T: Scalar + One> CsrMatrix<T> {
/// This ensures that the terminology is consistent: we are talking about rows and columns, /// This ensures that the terminology is consistent: we are talking about rows and columns,
/// not lanes, major and minor dimensions. /// not lanes, major and minor dimensions.
fn pattern_format_error_to_csr_error(err: SparsityPatternFormatError) -> SparseFormatError { fn pattern_format_error_to_csr_error(err: SparsityPatternFormatError) -> SparseFormatError {
use SparsityPatternFormatError::*;
use SparsityPatternFormatError::DuplicateEntry as PatternDuplicateEntry;
use SparseFormatError as E; use SparseFormatError as E;
use SparseFormatErrorKind as K; use SparseFormatErrorKind as K;
use SparsityPatternFormatError::DuplicateEntry as PatternDuplicateEntry;
use SparsityPatternFormatError::*;
match err { match err {
InvalidOffsetArrayLength => E::from_kind_and_msg( InvalidOffsetArrayLength => E::from_kind_and_msg(
K::InvalidStructure, K::InvalidStructure,
"Length of row offset array is not equal to nrows + 1."), "Length of row offset array is not equal to nrows + 1.",
),
InvalidOffsetFirstLast => E::from_kind_and_msg( InvalidOffsetFirstLast => E::from_kind_and_msg(
K::InvalidStructure, K::InvalidStructure,
"First or last row offset is inconsistent with format specification."), "First or last row offset is inconsistent with format specification.",
),
NonmonotonicOffsets => E::from_kind_and_msg( NonmonotonicOffsets => E::from_kind_and_msg(
K::InvalidStructure, K::InvalidStructure,
"Row offsets are not monotonically increasing."), "Row offsets are not monotonically increasing.",
),
NonmonotonicMinorIndices => E::from_kind_and_msg( NonmonotonicMinorIndices => E::from_kind_and_msg(
K::InvalidStructure, K::InvalidStructure,
"Column indices are not monotonically increasing (sorted) within each row."), "Column indices are not monotonically increasing (sorted) within each row.",
MinorIndexOutOfBounds => E::from_kind_and_msg( ),
K::IndexOutOfBounds, MinorIndexOutOfBounds => {
"Column indices are out of bounds."), E::from_kind_and_msg(K::IndexOutOfBounds, "Column indices are out of bounds.")
PatternDuplicateEntry => E::from_kind_and_msg( }
K::DuplicateEntry, PatternDuplicateEntry => {
"Matrix data contains duplicate entries."), E::from_kind_and_msg(K::DuplicateEntry, "Matrix data contains duplicate entries.")
}
} }
} }
@ -536,7 +551,7 @@ fn pattern_format_error_to_csr_error(err: SparsityPatternFormatError) -> SparseF
#[derive(Debug)] #[derive(Debug)]
pub struct CsrTripletIter<'a, T> { pub struct CsrTripletIter<'a, T> {
pattern_iter: SparsityPatternIter<'a>, pattern_iter: SparsityPatternIter<'a>,
values_iter: Iter<'a, T> values_iter: Iter<'a, T>,
} }
impl<'a, T: Clone> CsrTripletIter<'a, T> { impl<'a, T: Clone> CsrTripletIter<'a, T> {
@ -559,7 +574,7 @@ impl<'a, T> Iterator for CsrTripletIter<'a, T> {
match (next_entry, next_value) { match (next_entry, next_value) {
(Some((i, j)), Some(v)) => Some((i, j, v)), (Some((i, j)), Some(v)) => Some((i, j, v)),
_ => None _ => None,
} }
} }
} }
@ -568,7 +583,7 @@ impl<'a, T> Iterator for CsrTripletIter<'a, T> {
#[derive(Debug)] #[derive(Debug)]
pub struct CsrTripletIterMut<'a, T> { pub struct CsrTripletIterMut<'a, T> {
pattern_iter: SparsityPatternIter<'a>, pattern_iter: SparsityPatternIter<'a>,
values_mut_iter: IterMut<'a, T> values_mut_iter: IterMut<'a, T>,
} }
impl<'a, T> Iterator for CsrTripletIterMut<'a, T> { impl<'a, T> Iterator for CsrTripletIterMut<'a, T> {
@ -581,7 +596,7 @@ impl<'a, T> Iterator for CsrTripletIterMut<'a, T> {
match (next_entry, next_value) { match (next_entry, next_value) {
(Some((i, j)), Some(v)) => Some((i, j, v)), (Some((i, j)), Some(v)) => Some((i, j, v)),
_ => None _ => None,
} }
} }
} }
@ -589,7 +604,7 @@ impl<'a, T> Iterator for CsrTripletIterMut<'a, T> {
/// An immutable representation of a row in a CSR matrix. /// An immutable representation of a row in a CSR matrix.
#[derive(Debug, Clone, PartialEq, Eq)] #[derive(Debug, Clone, PartialEq, Eq)]
pub struct CsrRow<'a, T> { pub struct CsrRow<'a, T> {
lane: CsLane<'a, T> lane: CsLane<'a, T>,
} }
/// A mutable representation of a row in a CSR matrix. /// A mutable representation of a row in a CSR matrix.
@ -598,7 +613,7 @@ pub struct CsrRow<'a, T> {
/// to the row cannot be modified. /// to the row cannot be modified.
#[derive(Debug, PartialEq, Eq)] #[derive(Debug, PartialEq, Eq)]
pub struct CsrRowMut<'a, T> { pub struct CsrRowMut<'a, T> {
lane: CsLaneMut<'a, T> lane: CsLaneMut<'a, T>,
} }
/// Implement the methods common to both CsrRow and CsrRowMut /// Implement the methods common to both CsrRow and CsrRowMut
@ -638,7 +653,7 @@ macro_rules! impl_csr_row_common_methods {
self.lane.get_entry(global_col_index) self.lane.get_entry(global_col_index)
} }
} }
} };
} }
impl_csr_row_common_methods!(CsrRow<'a, T>); impl_csr_row_common_methods!(CsrRow<'a, T>);
@ -670,33 +685,29 @@ impl<'a, T> CsrRowMut<'a, T> {
/// Row iterator for [CsrMatrix](struct.CsrMatrix.html). /// Row iterator for [CsrMatrix](struct.CsrMatrix.html).
pub struct CsrRowIter<'a, T> { pub struct CsrRowIter<'a, T> {
lane_iter: CsLaneIter<'a, T> lane_iter: CsLaneIter<'a, T>,
} }
impl<'a, T> Iterator for CsrRowIter<'a, T> { impl<'a, T> Iterator for CsrRowIter<'a, T> {
type Item = CsrRow<'a, T>; type Item = CsrRow<'a, T>;
fn next(&mut self) -> Option<Self::Item> { fn next(&mut self) -> Option<Self::Item> {
self.lane_iter self.lane_iter.next().map(|lane| CsrRow { lane })
.next()
.map(|lane| CsrRow { lane })
} }
} }
/// Mutable row iterator for [CsrMatrix](struct.CsrMatrix.html). /// Mutable row iterator for [CsrMatrix](struct.CsrMatrix.html).
pub struct CsrRowIterMut<'a, T> { pub struct CsrRowIterMut<'a, T> {
lane_iter: CsLaneIterMut<'a, T> lane_iter: CsLaneIterMut<'a, T>,
} }
impl<'a, T> Iterator for CsrRowIterMut<'a, T> impl<'a, T> Iterator for CsrRowIterMut<'a, T>
where where
T: 'a T: 'a,
{ {
type Item = CsrRowMut<'a, T>; type Item = CsrRowMut<'a, T>;
fn next(&mut self) -> Option<Self::Item> { fn next(&mut self) -> Option<Self::Item> {
self.lane_iter self.lane_iter.next().map(|lane| CsrRowMut { lane })
.next()
.map(|lane| CsrRowMut { lane })
} }
} }

View File

@ -1,10 +1,10 @@
use crate::pattern::SparsityPattern;
use crate::csc::CscMatrix; use crate::csc::CscMatrix;
use core::{mem, iter};
use nalgebra::{Scalar, RealField, DMatrixSlice, DMatrixSliceMut, DMatrix};
use std::fmt::{Display, Formatter};
use crate::ops::serial::spsolve_csc_lower_triangular; use crate::ops::serial::spsolve_csc_lower_triangular;
use crate::ops::Op; use crate::ops::Op;
use crate::pattern::SparsityPattern;
use core::{iter, mem};
use nalgebra::{DMatrix, DMatrixSlice, DMatrixSliceMut, RealField, Scalar};
use std::fmt::{Display, Formatter};
/// A symbolic sparse Cholesky factorization of a CSC matrix. /// A symbolic sparse Cholesky factorization of a CSC matrix.
/// ///
@ -15,7 +15,7 @@ pub struct CscSymbolicCholesky {
m_pattern: SparsityPattern, m_pattern: SparsityPattern,
l_pattern: SparsityPattern, l_pattern: SparsityPattern,
// u in this context is L^T, so that M = L L^T // u in this context is L^T, so that M = L L^T
u_pattern: SparsityPattern u_pattern: SparsityPattern,
} }
impl CscSymbolicCholesky { impl CscSymbolicCholesky {
@ -28,8 +28,11 @@ impl CscSymbolicCholesky {
/// ///
/// Panics if the sparsity pattern is not square. /// Panics if the sparsity pattern is not square.
pub fn factor(pattern: SparsityPattern) -> Self { pub fn factor(pattern: SparsityPattern) -> Self {
assert_eq!(pattern.major_dim(), pattern.minor_dim(), assert_eq!(
"Major and minor dimensions must be the same (square matrix)."); pattern.major_dim(),
pattern.minor_dim(),
"Major and minor dimensions must be the same (square matrix)."
);
let (l_pattern, u_pattern) = nonzero_pattern(&pattern); let (l_pattern, u_pattern) = nonzero_pattern(&pattern);
Self { Self {
m_pattern: pattern, m_pattern: pattern,
@ -65,7 +68,7 @@ pub struct CscCholesky<T> {
l_factor: CscMatrix<T>, l_factor: CscMatrix<T>,
u_pattern: SparsityPattern, u_pattern: SparsityPattern,
work_x: Vec<T>, work_x: Vec<T>,
work_c: Vec<usize> work_c: Vec<usize>,
} }
#[derive(Debug, PartialEq, Eq, Clone)] #[derive(Debug, PartialEq, Eq, Clone)]
@ -100,16 +103,20 @@ impl<T: RealField> CscCholesky<T> {
/// ///
/// Panics if the number of values differ from the number of non-zeros of the sparsity pattern /// Panics if the number of values differ from the number of non-zeros of the sparsity pattern
/// of the matrix that was symbolically factored. /// of the matrix that was symbolically factored.
pub fn factor_numerical(symbolic: CscSymbolicCholesky, values: &[T]) pub fn factor_numerical(
-> Result<Self, CholeskyError> symbolic: CscSymbolicCholesky,
{ values: &[T],
assert_eq!(symbolic.l_pattern.nnz(), symbolic.u_pattern.nnz(), ) -> Result<Self, CholeskyError> {
"u is just the transpose of l, so should have the same nnz"); assert_eq!(
symbolic.l_pattern.nnz(),
symbolic.u_pattern.nnz(),
"u is just the transpose of l, so should have the same nnz"
);
let l_nnz = symbolic.l_pattern.nnz(); let l_nnz = symbolic.l_pattern.nnz();
let l_values = vec![T::zero(); l_nnz]; let l_values = vec![T::zero(); l_nnz];
let l_factor = CscMatrix::try_from_pattern_and_values(symbolic.l_pattern, l_values) let l_factor =
.unwrap(); CscMatrix::try_from_pattern_and_values(symbolic.l_pattern, l_values).unwrap();
let (nrows, ncols) = (l_factor.nrows(), l_factor.ncols()); let (nrows, ncols) = (l_factor.nrows(), l_factor.ncols());
@ -229,11 +236,9 @@ impl<T: RealField> CscCholesky<T> {
{ {
let (offsets, _, values) = self.l_factor.csc_data_mut(); let (offsets, _, values) = self.l_factor.csc_data_mut();
*values *values.get_unchecked_mut(*offsets.get_unchecked(k)) = denom;
.get_unchecked_mut(*offsets.get_unchecked(k)) = denom;
} }
let mut col_k = self.l_factor.col_mut(k); let mut col_k = self.l_factor.col_mut(k);
let (col_k_rows, col_k_values) = col_k.rows_and_values_mut(); let (col_k_rows, col_k_values) = col_k.rows_and_values_mut();
let col_k_entries = col_k_rows.iter().zip(col_k_values); let col_k_entries = col_k_rows.iter().zip(col_k_values);
@ -269,19 +274,16 @@ impl<T: RealField> CscCholesky<T> {
/// # Panics /// # Panics
/// ///
/// Panics if `b` is not square. /// Panics if `b` is not square.
pub fn solve_mut<'a>(&'a self, b: impl Into<DMatrixSliceMut<'a, T>>) pub fn solve_mut<'a>(&'a self, b: impl Into<DMatrixSliceMut<'a, T>>) {
{
let expect_msg = "If the Cholesky factorization succeeded,\ let expect_msg = "If the Cholesky factorization succeeded,\
then the triangular solve should never fail"; then the triangular solve should never fail";
// Solve LY = B // Solve LY = B
let mut y = b.into(); let mut y = b.into();
spsolve_csc_lower_triangular(Op::NoOp(self.l()), &mut y) spsolve_csc_lower_triangular(Op::NoOp(self.l()), &mut y).expect(expect_msg);
.expect(expect_msg);
// Solve L^T X = Y // Solve L^T X = Y
let mut x = y; let mut x = y;
spsolve_csc_lower_triangular(Op::Transpose(self.l()), &mut x) spsolve_csc_lower_triangular(Op::Transpose(self.l()), &mut x).expect(expect_msg);
.expect(expect_msg);
} }
} }
@ -333,8 +335,8 @@ fn nonzero_pattern(m: &SparsityPattern) -> (SparsityPattern, SparsityPattern) {
col_offsets.push(rows.len()); col_offsets.push(rows.len());
} }
let u_pattern = SparsityPattern::try_from_offsets_and_indices(nrows, ncols, col_offsets, rows) let u_pattern =
.unwrap(); SparsityPattern::try_from_offsets_and_indices(nrows, ncols, col_offsets, rows).unwrap();
// TODO: Avoid this transpose? // TODO: Avoid this transpose?
let l_pattern = u_pattern.transpose(); let l_pattern = u_pattern.transpose();

View File

@ -135,13 +135,13 @@
#![deny(unused_results)] #![deny(unused_results)]
#![deny(missing_docs)] #![deny(missing_docs)]
pub mod convert;
pub mod coo; pub mod coo;
pub mod csc; pub mod csc;
pub mod csr; pub mod csr;
pub mod pattern;
pub mod ops;
pub mod convert;
pub mod factorization; pub mod factorization;
pub mod ops;
pub mod pattern;
pub(crate) mod cs; pub(crate) mod cs;
@ -151,16 +151,16 @@ pub mod proptest;
#[cfg(feature = "compare")] #[cfg(feature = "compare")]
mod matrixcompare; mod matrixcompare;
use num_traits::Zero;
use std::error::Error; use std::error::Error;
use std::fmt; use std::fmt;
use num_traits::Zero;
/// Errors produced by functions that expect well-formed sparse format data. /// Errors produced by functions that expect well-formed sparse format data.
#[derive(Debug)] #[derive(Debug)]
pub struct SparseFormatError { pub struct SparseFormatError {
kind: SparseFormatErrorKind, kind: SparseFormatErrorKind,
// Currently we only use an underlying error for generating the `Display` impl // Currently we only use an underlying error for generating the `Display` impl
error: Box<dyn Error> error: Box<dyn Error>,
} }
impl SparseFormatError { impl SparseFormatError {
@ -170,10 +170,7 @@ impl SparseFormatError {
} }
pub(crate) fn from_kind_and_error(kind: SparseFormatErrorKind, error: Box<dyn Error>) -> Self { pub(crate) fn from_kind_and_error(kind: SparseFormatErrorKind, error: Box<dyn Error>) -> Self {
Self { Self { kind, error }
kind,
error
}
} }
/// Helper functionality for more conveniently creating errors. /// Helper functionality for more conveniently creating errors.
@ -221,7 +218,7 @@ pub enum SparseEntry<'a, T> {
/// is explicitly stored (a so-called "explicit zero"). /// is explicitly stored (a so-called "explicit zero").
NonZero(&'a T), NonZero(&'a T),
/// The entry is implicitly zero, i.e. it is not explicitly stored. /// The entry is implicitly zero, i.e. it is not explicitly stored.
Zero Zero,
} }
impl<'a, T: Clone + Zero> SparseEntry<'a, T> { impl<'a, T: Clone + Zero> SparseEntry<'a, T> {
@ -232,7 +229,7 @@ impl<'a, T: Clone + Zero> SparseEntry<'a, T> {
pub fn to_value(self) -> T { pub fn to_value(self) -> T {
match self { match self {
SparseEntry::NonZero(value) => value.clone(), SparseEntry::NonZero(value) => value.clone(),
SparseEntry::Zero => T::zero() SparseEntry::Zero => T::zero(),
} }
} }
} }
@ -248,7 +245,7 @@ pub enum SparseEntryMut<'a, T> {
/// is explicitly stored (a so-called "explicit zero"). /// is explicitly stored (a so-called "explicit zero").
NonZero(&'a mut T), NonZero(&'a mut T),
/// The entry is implicitly zero i.e. it is not explicitly stored. /// The entry is implicitly zero i.e. it is not explicitly stored.
Zero Zero,
} }
impl<'a, T: Clone + Zero> SparseEntryMut<'a, T> { impl<'a, T: Clone + Zero> SparseEntryMut<'a, T> {
@ -259,7 +256,7 @@ impl<'a, T: Clone + Zero> SparseEntryMut<'a, T> {
pub fn to_value(self) -> T { pub fn to_value(self) -> T {
match self { match self {
SparseEntryMut::NonZero(value) => value.clone(), SparseEntryMut::NonZero(value) => value.clone(),
SparseEntryMut::Zero => T::zero() SparseEntryMut::Zero => T::zero(),
} }
} }
} }

View File

@ -1,9 +1,9 @@
//! Implements core traits for use with `matrixcompare`. //! Implements core traits for use with `matrixcompare`.
use crate::csr::CsrMatrix; use crate::coo::CooMatrix;
use crate::csc::CscMatrix; use crate::csc::CscMatrix;
use crate::csr::CsrMatrix;
use matrixcompare_core; use matrixcompare_core;
use matrixcompare_core::{Access, SparseAccess}; use matrixcompare_core::{Access, SparseAccess};
use crate::coo::CooMatrix;
macro_rules! impl_matrix_for_csr_csc { macro_rules! impl_matrix_for_csr_csc {
($MatrixType:ident) => { ($MatrixType:ident) => {
@ -13,7 +13,9 @@ macro_rules! impl_matrix_for_csr_csc {
} }
fn fetch_triplets(&self) -> Vec<(usize, usize, T)> { fn fetch_triplets(&self) -> Vec<(usize, usize, T)> {
self.triplet_iter().map(|(i, j, v)| (i, j, v.clone())).collect() self.triplet_iter()
.map(|(i, j, v)| (i, j, v.clone()))
.collect()
} }
} }
@ -30,7 +32,7 @@ macro_rules! impl_matrix_for_csr_csc {
Access::Sparse(self) Access::Sparse(self)
} }
} }
} };
} }
impl_matrix_for_csr_csc!(CsrMatrix); impl_matrix_for_csr_csc!(CsrMatrix);
@ -42,7 +44,9 @@ impl<T: Clone> SparseAccess<T> for CooMatrix<T> {
} }
fn fetch_triplets(&self) -> Vec<(usize, usize, T)> { fn fetch_triplets(&self) -> Vec<(usize, usize, T)> {
self.triplet_iter().map(|(i, j, v)| (i, j, v.clone())).collect() self.triplet_iter()
.map(|(i, j, v)| (i, j, v.clone()))
.collect()
} }
} }

View File

@ -1,15 +1,20 @@
use crate::csr::CsrMatrix;
use crate::csc::CscMatrix; use crate::csc::CscMatrix;
use crate::csr::CsrMatrix;
use std::ops::{Add, Div, DivAssign, Mul, MulAssign, Sub, Neg}; use crate::ops::serial::{
use crate::ops::serial::{spadd_csr_prealloc, spadd_csc_prealloc, spadd_pattern, spmm_csr_pattern, spmm_csr_prealloc, spmm_csc_prealloc, spmm_csc_dense, spmm_csr_dense, spmm_csc_pattern}; spadd_csc_prealloc, spadd_csr_prealloc, spadd_pattern, spmm_csc_dense, spmm_csc_pattern,
use nalgebra::{ClosedAdd, ClosedMul, ClosedSub, ClosedDiv, Scalar, Matrix, MatrixMN, Dim, spmm_csc_prealloc, spmm_csr_dense, spmm_csr_pattern, spmm_csr_prealloc,
Dynamic, DefaultAllocator, U1}; };
use nalgebra::allocator::{Allocator}; use crate::ops::Op;
use nalgebra::constraint::{DimEq, ShapeConstraint}; use nalgebra::allocator::Allocator;
use num_traits::{Zero, One};
use crate::ops::{Op};
use nalgebra::base::storage::Storage; use nalgebra::base::storage::Storage;
use nalgebra::constraint::{DimEq, ShapeConstraint};
use nalgebra::{
ClosedAdd, ClosedDiv, ClosedMul, ClosedSub, DefaultAllocator, Dim, Dynamic, Matrix, MatrixMN,
Scalar, U1,
};
use num_traits::{One, Zero};
use std::ops::{Add, Div, DivAssign, Mul, MulAssign, Neg, Sub};
/// Helper macro for implementing binary operators for different matrix types /// Helper macro for implementing binary operators for different matrix types
/// See below for usage. /// See below for usage.
@ -188,7 +193,7 @@ macro_rules! impl_neg {
($matrix_type:ident) => { ($matrix_type:ident) => {
impl<T> Neg for $matrix_type<T> impl<T> Neg for $matrix_type<T>
where where
T: Scalar + Neg<Output=T> T: Scalar + Neg<Output = T>,
{ {
type Output = $matrix_type<T>; type Output = $matrix_type<T>;
@ -202,7 +207,7 @@ macro_rules! impl_neg {
impl<'a, T> Neg for &'a $matrix_type<T> impl<'a, T> Neg for &'a $matrix_type<T>
where where
T: Scalar + Neg<Output=T> T: Scalar + Neg<Output = T>,
{ {
type Output = $matrix_type<T>; type Output = $matrix_type<T>;
@ -214,7 +219,7 @@ macro_rules! impl_neg {
-self.clone() -self.clone()
} }
} }
} };
} }
impl_neg!(CsrMatrix); impl_neg!(CsrMatrix);

View File

@ -148,13 +148,14 @@ impl<T> Op<T> {
pub fn as_ref(&self) -> Op<&T> { pub fn as_ref(&self) -> Op<&T> {
match self { match self {
Op::NoOp(obj) => Op::NoOp(&obj), Op::NoOp(obj) => Op::NoOp(&obj),
Op::Transpose(obj) => Op::Transpose(&obj) Op::Transpose(obj) => Op::Transpose(&obj),
} }
} }
/// Converts the underlying data type. /// Converts the underlying data type.
pub fn convert<U>(self) -> Op<U> pub fn convert<U>(self) -> Op<U>
where T: Into<U> where
T: Into<U>,
{ {
self.map_same_op(T::into) self.map_same_op(T::into)
} }
@ -163,7 +164,7 @@ impl<T> Op<T> {
pub fn map_same_op<U, F: FnOnce(T) -> U>(self, f: F) -> Op<U> { pub fn map_same_op<U, F: FnOnce(T) -> U>(self, f: F) -> Op<U> {
match self { match self {
Op::NoOp(obj) => Op::NoOp(f(obj)), Op::NoOp(obj) => Op::NoOp(f(obj)),
Op::Transpose(obj) => Op::Transpose(f(obj)) Op::Transpose(obj) => Op::Transpose(f(obj)),
} }
} }
@ -181,7 +182,7 @@ impl<T> Op<T> {
pub fn transposed(self) -> Self { pub fn transposed(self) -> Self {
match self { match self {
Op::NoOp(obj) => Op::Transpose(obj), Op::NoOp(obj) => Op::Transpose(obj),
Op::Transpose(obj) => Op::NoOp(obj) Op::Transpose(obj) => Op::NoOp(obj),
} }
} }
} }
@ -191,4 +192,3 @@ impl<T> From<T> for Op<T> {
Self::NoOp(obj) Self::NoOp(obj)
} }
} }

View File

@ -1,14 +1,15 @@
use crate::cs::CsMatrix; use crate::cs::CsMatrix;
use crate::ops::serial::{OperationError, OperationErrorKind};
use crate::ops::Op; use crate::ops::Op;
use crate::ops::serial::{OperationErrorKind, OperationError};
use nalgebra::{Scalar, ClosedAdd, ClosedMul, DMatrixSliceMut, DMatrixSlice};
use num_traits::{Zero, One};
use crate::SparseEntryMut; use crate::SparseEntryMut;
use nalgebra::{ClosedAdd, ClosedMul, DMatrixSlice, DMatrixSliceMut, Scalar};
use num_traits::{One, Zero};
fn spmm_cs_unexpected_entry() -> OperationError { fn spmm_cs_unexpected_entry() -> OperationError {
OperationError::from_kind_and_message( OperationError::from_kind_and_message(
OperationErrorKind::InvalidPattern, OperationErrorKind::InvalidPattern,
String::from("Found unexpected entry that is not present in `c`.")) String::from("Found unexpected entry that is not present in `c`."),
)
} }
/// Helper functionality for implementing CSR/CSC SPMM. /// Helper functionality for implementing CSR/CSC SPMM.
@ -24,10 +25,10 @@ pub fn spmm_cs_prealloc<T>(
c: &mut CsMatrix<T>, c: &mut CsMatrix<T>,
alpha: T, alpha: T,
a: &CsMatrix<T>, a: &CsMatrix<T>,
b: &CsMatrix<T>) b: &CsMatrix<T>,
-> Result<(), OperationError> ) -> Result<(), OperationError>
where where
T: Scalar + ClosedAdd + ClosedMul + Zero + One T: Scalar + ClosedAdd + ClosedMul + Zero + One,
{ {
for i in 0..c.pattern().major_dim() { for i in 0..c.pattern().major_dim() {
let a_lane_i = a.get_lane(i).unwrap(); let a_lane_i = a.get_lane(i).unwrap();
@ -42,7 +43,8 @@ pub fn spmm_cs_prealloc<T>(
let alpha_aik = alpha.inlined_clone() * a_ik.inlined_clone(); let alpha_aik = alpha.inlined_clone() * a_ik.inlined_clone();
for (j, b_kj) in b_lane_k.minor_indices().iter().zip(b_lane_k.values()) { for (j, b_kj) in b_lane_k.minor_indices().iter().zip(b_lane_k.values()) {
// Determine the location in C to append the value // Determine the location in C to append the value
let (c_local_idx, _) = c_lane_i_cols.iter() let (c_local_idx, _) = c_lane_i_cols
.iter()
.enumerate() .enumerate()
.find(|(_, c_col)| *c_col == j) .find(|(_, c_col)| *c_col == j)
.ok_or_else(spmm_cs_unexpected_entry)?; .ok_or_else(spmm_cs_unexpected_entry)?;
@ -60,17 +62,19 @@ pub fn spmm_cs_prealloc<T>(
fn spadd_cs_unexpected_entry() -> OperationError { fn spadd_cs_unexpected_entry() -> OperationError {
OperationError::from_kind_and_message( OperationError::from_kind_and_message(
OperationErrorKind::InvalidPattern, OperationErrorKind::InvalidPattern,
String::from("Found entry in `op(a)` that is not present in `c`.")) String::from("Found entry in `op(a)` that is not present in `c`."),
)
} }
/// Helper functionality for implementing CSR/CSC SPADD. /// Helper functionality for implementing CSR/CSC SPADD.
pub fn spadd_cs_prealloc<T>(beta: T, pub fn spadd_cs_prealloc<T>(
beta: T,
c: &mut CsMatrix<T>, c: &mut CsMatrix<T>,
alpha: T, alpha: T,
a: Op<&CsMatrix<T>>) a: Op<&CsMatrix<T>>,
-> Result<(), OperationError> ) -> Result<(), OperationError>
where where
T: Scalar + ClosedAdd + ClosedMul + Zero + One T: Scalar + ClosedAdd + ClosedMul + Zero + One,
{ {
match a { match a {
Op::NoOp(a) => { Op::NoOp(a) => {
@ -88,7 +92,8 @@ pub fn spadd_cs_prealloc<T>(beta: T,
// TODO: Use exponential search instead of linear search. // TODO: Use exponential search instead of linear search.
// If C has substantially more entries in the row than A, then a line search // If C has substantially more entries in the row than A, then a line search
// will needlessly visit many entries in C. // will needlessly visit many entries in C.
let (c_idx, _) = c_minors.iter() let (c_idx, _) = c_minors
.iter()
.enumerate() .enumerate()
.find(|(_, c_col)| *c_col == a_col) .find(|(_, c_col)| *c_col == a_col)
.ok_or_else(spadd_cs_unexpected_entry)?; .ok_or_else(spadd_cs_unexpected_entry)?;
@ -110,7 +115,7 @@ pub fn spadd_cs_prealloc<T>(beta: T,
let a_val = a_val.inlined_clone(); let a_val = a_val.inlined_clone();
let alpha = alpha.inlined_clone(); let alpha = alpha.inlined_clone();
match c.get_entry_mut(j, i).unwrap() { match c.get_entry_mut(j, i).unwrap() {
SparseEntryMut::NonZero(c_ji) => { *c_ji += alpha * a_val } SparseEntryMut::NonZero(c_ji) => *c_ji += alpha * a_val,
SparseEntryMut::Zero => return Err(spadd_cs_unexpected_entry()), SparseEntryMut::Zero => return Err(spadd_cs_unexpected_entry()),
} }
} }
@ -124,13 +129,14 @@ pub fn spadd_cs_prealloc<T>(beta: T,
/// ///
/// The implementation essentially assumes that `a` is a CSR matrix. To use it with CSC matrices, /// The implementation essentially assumes that `a` is a CSR matrix. To use it with CSC matrices,
/// the transposed operation must be specified for the CSC matrix. /// the transposed operation must be specified for the CSC matrix.
pub fn spmm_cs_dense<T>(beta: T, pub fn spmm_cs_dense<T>(
beta: T,
mut c: DMatrixSliceMut<T>, mut c: DMatrixSliceMut<T>,
alpha: T, alpha: T,
a: Op<&CsMatrix<T>>, a: Op<&CsMatrix<T>>,
b: Op<DMatrixSlice<T>>) b: Op<DMatrixSlice<T>>,
where ) where
T: Scalar + ClosedAdd + ClosedMul + Zero + One T: Scalar + ClosedAdd + ClosedMul + Zero + One,
{ {
match a { match a {
Op::NoOp(a) => { Op::NoOp(a) => {
@ -139,17 +145,17 @@ pub fn spmm_cs_dense<T>(beta: T,
for (c_ij, a_row_i) in c_col_j.iter_mut().zip(a.lane_iter()) { for (c_ij, a_row_i) in c_col_j.iter_mut().zip(a.lane_iter()) {
let mut dot_ij = T::zero(); let mut dot_ij = T::zero();
for (&k, a_ik) in a_row_i.minor_indices().iter().zip(a_row_i.values()) { for (&k, a_ik) in a_row_i.minor_indices().iter().zip(a_row_i.values()) {
let b_contrib = let b_contrib = match b {
match b {
Op::NoOp(ref b) => b.index((k, j)), Op::NoOp(ref b) => b.index((k, j)),
Op::Transpose(ref b) => b.index((j, k)) Op::Transpose(ref b) => b.index((j, k)),
}; };
dot_ij += a_ik.inlined_clone() * b_contrib.inlined_clone(); dot_ij += a_ik.inlined_clone() * b_contrib.inlined_clone();
} }
*c_ij = beta.inlined_clone() * c_ij.inlined_clone() + alpha.inlined_clone() * dot_ij; *c_ij = beta.inlined_clone() * c_ij.inlined_clone()
+ alpha.inlined_clone() * dot_ij;
}
} }
} }
},
Op::Transpose(a) => { Op::Transpose(a) => {
// In this case, we have to pre-multiply C by beta // In this case, we have to pre-multiply C by beta
c *= beta; c *= beta;
@ -165,17 +171,16 @@ pub fn spmm_cs_dense<T>(beta: T,
for (c_ij, b_kj) in c_row_i.iter_mut().zip(b_row_k.iter()) { for (c_ij, b_kj) in c_row_i.iter_mut().zip(b_row_k.iter()) {
*c_ij += gamma_ki.inlined_clone() * b_kj.inlined_clone(); *c_ij += gamma_ki.inlined_clone() * b_kj.inlined_clone();
} }
}, }
Op::Transpose(ref b) => { Op::Transpose(ref b) => {
let b_col_k = b.column(k); let b_col_k = b.column(k);
for (c_ij, b_jk) in c_row_i.iter_mut().zip(b_col_k.iter()) { for (c_ij, b_jk) in c_row_i.iter_mut().zip(b_col_k.iter()) {
*c_ij += gamma_ki.inlined_clone() * b_jk.inlined_clone(); *c_ij += gamma_ki.inlined_clone() * b_jk.inlined_clone();
} }
},
} }
} }
} }
},
} }
} }
}
}

View File

@ -1,9 +1,9 @@
use crate::csc::CscMatrix; use crate::csc::CscMatrix;
use crate::ops::Op; use crate::ops::serial::cs::{spadd_cs_prealloc, spmm_cs_dense, spmm_cs_prealloc};
use crate::ops::serial::cs::{spmm_cs_prealloc, spmm_cs_dense, spadd_cs_prealloc};
use crate::ops::serial::{OperationError, OperationErrorKind}; use crate::ops::serial::{OperationError, OperationErrorKind};
use nalgebra::{Scalar, ClosedAdd, ClosedMul, DMatrixSliceMut, DMatrixSlice, RealField}; use crate::ops::Op;
use num_traits::{Zero, One}; use nalgebra::{ClosedAdd, ClosedMul, DMatrixSlice, DMatrixSliceMut, RealField, Scalar};
use num_traits::{One, Zero};
use std::borrow::Cow; use std::borrow::Cow;
@ -12,25 +12,27 @@ use std::borrow::Cow;
/// # Panics /// # Panics
/// ///
/// Panics if the dimensions of the matrices involved are not compatible with the expression. /// Panics if the dimensions of the matrices involved are not compatible with the expression.
pub fn spmm_csc_dense<'a, T>(beta: T, pub fn spmm_csc_dense<'a, T>(
beta: T,
c: impl Into<DMatrixSliceMut<'a, T>>, c: impl Into<DMatrixSliceMut<'a, T>>,
alpha: T, alpha: T,
a: Op<&CscMatrix<T>>, a: Op<&CscMatrix<T>>,
b: Op<impl Into<DMatrixSlice<'a, T>>>) b: Op<impl Into<DMatrixSlice<'a, T>>>,
where ) where
T: Scalar + ClosedAdd + ClosedMul + Zero + One T: Scalar + ClosedAdd + ClosedMul + Zero + One,
{ {
let b = b.convert(); let b = b.convert();
spmm_csc_dense_(beta, c.into(), alpha, a, b) spmm_csc_dense_(beta, c.into(), alpha, a, b)
} }
fn spmm_csc_dense_<T>(beta: T, fn spmm_csc_dense_<T>(
beta: T,
c: DMatrixSliceMut<T>, c: DMatrixSliceMut<T>,
alpha: T, alpha: T,
a: Op<&CscMatrix<T>>, a: Op<&CscMatrix<T>>,
b: Op<DMatrixSlice<T>>) b: Op<DMatrixSlice<T>>,
where ) where
T: Scalar + ClosedAdd + ClosedMul + Zero + One T: Scalar + ClosedAdd + ClosedMul + Zero + One,
{ {
assert_compatible_spmm_dims!(c, a, b); assert_compatible_spmm_dims!(c, a, b);
// Need to interpret matrix as transposed since the spmm_cs_dense function assumes CSR layout // Need to interpret matrix as transposed since the spmm_cs_dense function assumes CSR layout
@ -46,19 +48,19 @@ fn spmm_csc_dense_<T>(beta: T,
/// # Panics /// # Panics
/// ///
/// Panics if the dimensions of the matrices involved are not compatible with the expression. /// Panics if the dimensions of the matrices involved are not compatible with the expression.
pub fn spadd_csc_prealloc<T>(beta: T, pub fn spadd_csc_prealloc<T>(
beta: T,
c: &mut CscMatrix<T>, c: &mut CscMatrix<T>,
alpha: T, alpha: T,
a: Op<&CscMatrix<T>>) a: Op<&CscMatrix<T>>,
-> Result<(), OperationError> ) -> Result<(), OperationError>
where where
T: Scalar + ClosedAdd + ClosedMul + Zero + One T: Scalar + ClosedAdd + ClosedMul + Zero + One,
{ {
assert_compatible_spadd_dims!(c, a); assert_compatible_spadd_dims!(c, a);
spadd_cs_prealloc(beta, &mut c.cs, alpha, a.map_same_op(|a| &a.cs)) spadd_cs_prealloc(beta, &mut c.cs, alpha, a.map_same_op(|a| &a.cs))
} }
/// Sparse-sparse matrix multiplication, `C <- beta * C + alpha * op(A) * op(B)`. /// Sparse-sparse matrix multiplication, `C <- beta * C + alpha * op(A) * op(B)`.
/// ///
/// # Errors /// # Errors
@ -74,10 +76,10 @@ pub fn spmm_csc_prealloc<T>(
c: &mut CscMatrix<T>, c: &mut CscMatrix<T>,
alpha: T, alpha: T,
a: Op<&CscMatrix<T>>, a: Op<&CscMatrix<T>>,
b: Op<&CscMatrix<T>>) b: Op<&CscMatrix<T>>,
-> Result<(), OperationError> ) -> Result<(), OperationError>
where where
T: Scalar + ClosedAdd + ClosedMul + Zero + One T: Scalar + ClosedAdd + ClosedMul + Zero + One,
{ {
assert_compatible_spmm_dims!(c, a, b); assert_compatible_spmm_dims!(c, a, b);
@ -87,7 +89,7 @@ pub fn spmm_csc_prealloc<T>(
(NoOp(ref a), NoOp(ref b)) => { (NoOp(ref a), NoOp(ref b)) => {
// Note: We have to reverse the order for CSC matrices // Note: We have to reverse the order for CSC matrices
spmm_cs_prealloc(beta, &mut c.cs, alpha, &b.cs, &a.cs) spmm_cs_prealloc(beta, &mut c.cs, alpha, &b.cs, &a.cs)
}, }
_ => { _ => {
// Currently we handle transposition by explicitly precomputing transposed matrices // Currently we handle transposition by explicitly precomputing transposed matrices
// and calling the operation again without transposition // and calling the operation again without transposition
@ -99,7 +101,9 @@ pub fn spmm_csc_prealloc<T>(
(NoOp(_), NoOp(_)) => unreachable!(), (NoOp(_), NoOp(_)) => unreachable!(),
(Transpose(ref a), NoOp(_)) => (Owned(a.transpose()), Borrowed(b_ref)), (Transpose(ref a), NoOp(_)) => (Owned(a.transpose()), Borrowed(b_ref)),
(NoOp(_), Transpose(ref b)) => (Borrowed(a_ref), Owned(b.transpose())), (NoOp(_), Transpose(ref b)) => (Borrowed(a_ref), Owned(b.transpose())),
(Transpose(ref a), Transpose(ref b)) => (Owned(a.transpose()), Owned(b.transpose())) (Transpose(ref a), Transpose(ref b)) => {
(Owned(a.transpose()), Owned(b.transpose()))
}
} }
}; };
@ -121,13 +125,20 @@ pub fn spmm_csc_prealloc<T>(
/// Panics if `L` is not square, or if `L` and `B` are not dimensionally compatible. /// Panics if `L` is not square, or if `L` and `B` are not dimensionally compatible.
pub fn spsolve_csc_lower_triangular<'a, T: RealField>( pub fn spsolve_csc_lower_triangular<'a, T: RealField>(
l: Op<&CscMatrix<T>>, l: Op<&CscMatrix<T>>,
b: impl Into<DMatrixSliceMut<'a, T>>) b: impl Into<DMatrixSliceMut<'a, T>>,
-> Result<(), OperationError> ) -> Result<(), OperationError> {
{
let b = b.into(); let b = b.into();
let l_matrix = l.into_inner(); let l_matrix = l.into_inner();
assert_eq!(l_matrix.nrows(), l_matrix.ncols(), "Matrix must be square for triangular solve."); assert_eq!(
assert_eq!(l_matrix.nrows(), b.nrows(), "Dimension mismatch in sparse lower triangular solver."); l_matrix.nrows(),
l_matrix.ncols(),
"Matrix must be square for triangular solve."
);
assert_eq!(
l_matrix.nrows(),
b.nrows(),
"Dimension mismatch in sparse lower triangular solver."
);
match l { match l {
Op::NoOp(a) => spsolve_csc_lower_triangular_no_transpose(a, b), Op::NoOp(a) => spsolve_csc_lower_triangular_no_transpose(a, b),
Op::Transpose(a) => spsolve_csc_lower_triangular_transpose(a, b), Op::Transpose(a) => spsolve_csc_lower_triangular_transpose(a, b),
@ -136,9 +147,8 @@ pub fn spsolve_csc_lower_triangular<'a, T: RealField>(
fn spsolve_csc_lower_triangular_no_transpose<T: RealField>( fn spsolve_csc_lower_triangular_no_transpose<T: RealField>(
l: &CscMatrix<T>, l: &CscMatrix<T>,
b: DMatrixSliceMut<T>) b: DMatrixSliceMut<T>,
-> Result<(), OperationError> ) -> Result<(), OperationError> {
{
let mut x = b; let mut x = b;
// Solve column-by-column // Solve column-by-column
@ -187,14 +197,16 @@ fn spsolve_csc_lower_triangular_no_transpose<T: RealField>(
fn spsolve_encountered_zero_diagonal() -> Result<(), OperationError> { fn spsolve_encountered_zero_diagonal() -> Result<(), OperationError> {
let message = "Matrix contains at least one diagonal entry that is zero."; let message = "Matrix contains at least one diagonal entry that is zero.";
Err(OperationError::from_kind_and_message(OperationErrorKind::Singular, String::from(message))) Err(OperationError::from_kind_and_message(
OperationErrorKind::Singular,
String::from(message),
))
} }
fn spsolve_csc_lower_triangular_transpose<T: RealField>( fn spsolve_csc_lower_triangular_transpose<T: RealField>(
l: &CscMatrix<T>, l: &CscMatrix<T>,
b: DMatrixSliceMut<T>) b: DMatrixSliceMut<T>,
-> Result<(), OperationError> ) -> Result<(), OperationError> {
{
let mut x = b; let mut x = b;
// Solve column-by-column // Solve column-by-column

View File

@ -1,31 +1,33 @@
use crate::csr::CsrMatrix; use crate::csr::CsrMatrix;
use crate::ops::{Op}; use crate::ops::serial::cs::{spadd_cs_prealloc, spmm_cs_dense, spmm_cs_prealloc};
use crate::ops::serial::{OperationError}; use crate::ops::serial::OperationError;
use nalgebra::{Scalar, DMatrixSlice, ClosedAdd, ClosedMul, DMatrixSliceMut}; use crate::ops::Op;
use num_traits::{Zero, One}; use nalgebra::{ClosedAdd, ClosedMul, DMatrixSlice, DMatrixSliceMut, Scalar};
use num_traits::{One, Zero};
use std::borrow::Cow; use std::borrow::Cow;
use crate::ops::serial::cs::{spmm_cs_prealloc, spmm_cs_dense, spadd_cs_prealloc};
/// Sparse-dense matrix-matrix multiplication `C <- beta * C + alpha * op(A) * op(B)`. /// Sparse-dense matrix-matrix multiplication `C <- beta * C + alpha * op(A) * op(B)`.
pub fn spmm_csr_dense<'a, T>(beta: T, pub fn spmm_csr_dense<'a, T>(
beta: T,
c: impl Into<DMatrixSliceMut<'a, T>>, c: impl Into<DMatrixSliceMut<'a, T>>,
alpha: T, alpha: T,
a: Op<&CsrMatrix<T>>, a: Op<&CsrMatrix<T>>,
b: Op<impl Into<DMatrixSlice<'a, T>>>) b: Op<impl Into<DMatrixSlice<'a, T>>>,
where ) where
T: Scalar + ClosedAdd + ClosedMul + Zero + One T: Scalar + ClosedAdd + ClosedMul + Zero + One,
{ {
let b = b.convert(); let b = b.convert();
spmm_csr_dense_(beta, c.into(), alpha, a, b) spmm_csr_dense_(beta, c.into(), alpha, a, b)
} }
fn spmm_csr_dense_<T>(beta: T, fn spmm_csr_dense_<T>(
beta: T,
c: DMatrixSliceMut<T>, c: DMatrixSliceMut<T>,
alpha: T, alpha: T,
a: Op<&CsrMatrix<T>>, a: Op<&CsrMatrix<T>>,
b: Op<DMatrixSlice<T>>) b: Op<DMatrixSlice<T>>,
where ) where
T: Scalar + ClosedAdd + ClosedMul + Zero + One T: Scalar + ClosedAdd + ClosedMul + Zero + One,
{ {
assert_compatible_spmm_dims!(c, a, b); assert_compatible_spmm_dims!(c, a, b);
spmm_cs_dense(beta, c, alpha, a.map_same_op(|a| &a.cs), b) spmm_cs_dense(beta, c, alpha, a.map_same_op(|a| &a.cs), b)
@ -41,13 +43,14 @@ where
/// # Panics /// # Panics
/// ///
/// Panics if the dimensions of the matrices involved are not compatible with the expression. /// Panics if the dimensions of the matrices involved are not compatible with the expression.
pub fn spadd_csr_prealloc<T>(beta: T, pub fn spadd_csr_prealloc<T>(
beta: T,
c: &mut CsrMatrix<T>, c: &mut CsrMatrix<T>,
alpha: T, alpha: T,
a: Op<&CsrMatrix<T>>) a: Op<&CsrMatrix<T>>,
-> Result<(), OperationError> ) -> Result<(), OperationError>
where where
T: Scalar + ClosedAdd + ClosedMul + Zero + One T: Scalar + ClosedAdd + ClosedMul + Zero + One,
{ {
assert_compatible_spadd_dims!(c, a); assert_compatible_spadd_dims!(c, a);
spadd_cs_prealloc(beta, &mut c.cs, alpha, a.map_same_op(|a| &a.cs)) spadd_cs_prealloc(beta, &mut c.cs, alpha, a.map_same_op(|a| &a.cs))
@ -67,19 +70,17 @@ pub fn spmm_csr_prealloc<T>(
c: &mut CsrMatrix<T>, c: &mut CsrMatrix<T>,
alpha: T, alpha: T,
a: Op<&CsrMatrix<T>>, a: Op<&CsrMatrix<T>>,
b: Op<&CsrMatrix<T>>) b: Op<&CsrMatrix<T>>,
-> Result<(), OperationError> ) -> Result<(), OperationError>
where where
T: Scalar + ClosedAdd + ClosedMul + Zero + One T: Scalar + ClosedAdd + ClosedMul + Zero + One,
{ {
assert_compatible_spmm_dims!(c, a, b); assert_compatible_spmm_dims!(c, a, b);
use Op::{NoOp, Transpose}; use Op::{NoOp, Transpose};
match (&a, &b) { match (&a, &b) {
(NoOp(ref a), NoOp(ref b)) => { (NoOp(ref a), NoOp(ref b)) => spmm_cs_prealloc(beta, &mut c.cs, alpha, &a.cs, &b.cs),
spmm_cs_prealloc(beta, &mut c.cs, alpha, &a.cs, &b.cs)
},
_ => { _ => {
// Currently we handle transposition by explicitly precomputing transposed matrices // Currently we handle transposition by explicitly precomputing transposed matrices
// and calling the operation again without transposition // and calling the operation again without transposition
@ -93,7 +94,9 @@ where
(NoOp(_), NoOp(_)) => unreachable!(), (NoOp(_), NoOp(_)) => unreachable!(),
(Transpose(ref a), NoOp(_)) => (Owned(a.transpose()), Borrowed(b_ref)), (Transpose(ref a), NoOp(_)) => (Owned(a.transpose()), Borrowed(b_ref)),
(NoOp(_), Transpose(ref b)) => (Borrowed(a_ref), Owned(b.transpose())), (NoOp(_), Transpose(ref b)) => (Borrowed(a_ref), Owned(b.transpose())),
(Transpose(ref a), Transpose(ref b)) => (Owned(a.transpose()), Owned(b.transpose())) (Transpose(ref a), Transpose(ref b)) => {
(Owned(a.transpose()), Owned(b.transpose()))
}
} }
}; };
@ -101,4 +104,3 @@ where
} }
} }
} }

View File

@ -10,33 +10,31 @@
#[macro_use] #[macro_use]
macro_rules! assert_compatible_spmm_dims { macro_rules! assert_compatible_spmm_dims {
($c:expr, $a:expr, $b:expr) => { ($c:expr, $a:expr, $b:expr) => {{
{
use crate::ops::Op::{NoOp, Transpose}; use crate::ops::Op::{NoOp, Transpose};
match (&$a, &$b) { match (&$a, &$b) {
(NoOp(ref a), NoOp(ref b)) => { (NoOp(ref a), NoOp(ref b)) => {
assert_eq!($c.nrows(), a.nrows(), "C.nrows() != A.nrows()"); assert_eq!($c.nrows(), a.nrows(), "C.nrows() != A.nrows()");
assert_eq!($c.ncols(), b.ncols(), "C.ncols() != B.ncols()"); assert_eq!($c.ncols(), b.ncols(), "C.ncols() != B.ncols()");
assert_eq!(a.ncols(), b.nrows(), "A.ncols() != B.nrows()"); assert_eq!(a.ncols(), b.nrows(), "A.ncols() != B.nrows()");
}, }
(Transpose(ref a), NoOp(ref b)) => { (Transpose(ref a), NoOp(ref b)) => {
assert_eq!($c.nrows(), a.ncols(), "C.nrows() != A.ncols()"); assert_eq!($c.nrows(), a.ncols(), "C.nrows() != A.ncols()");
assert_eq!($c.ncols(), b.ncols(), "C.ncols() != B.ncols()"); assert_eq!($c.ncols(), b.ncols(), "C.ncols() != B.ncols()");
assert_eq!(a.nrows(), b.nrows(), "A.nrows() != B.nrows()"); assert_eq!(a.nrows(), b.nrows(), "A.nrows() != B.nrows()");
}, }
(NoOp(ref a), Transpose(ref b)) => { (NoOp(ref a), Transpose(ref b)) => {
assert_eq!($c.nrows(), a.nrows(), "C.nrows() != A.nrows()"); assert_eq!($c.nrows(), a.nrows(), "C.nrows() != A.nrows()");
assert_eq!($c.ncols(), b.nrows(), "C.ncols() != B.nrows()"); assert_eq!($c.ncols(), b.nrows(), "C.ncols() != B.nrows()");
assert_eq!(a.ncols(), b.ncols(), "A.ncols() != B.ncols()"); assert_eq!(a.ncols(), b.ncols(), "A.ncols() != B.ncols()");
}, }
(Transpose(ref a), Transpose(ref b)) => { (Transpose(ref a), Transpose(ref b)) => {
assert_eq!($c.nrows(), a.ncols(), "C.nrows() != A.ncols()"); assert_eq!($c.nrows(), a.ncols(), "C.nrows() != A.ncols()");
assert_eq!($c.ncols(), b.nrows(), "C.ncols() != B.nrows()"); assert_eq!($c.ncols(), b.nrows(), "C.ncols() != B.nrows()");
assert_eq!(a.nrows(), b.ncols(), "A.nrows() != B.ncols()"); assert_eq!(a.nrows(), b.ncols(), "A.nrows() != B.ncols()");
} }
} }
} }};
}
} }
#[macro_use] #[macro_use]
@ -47,32 +45,31 @@ macro_rules! assert_compatible_spadd_dims {
Op::NoOp(a) => { Op::NoOp(a) => {
assert_eq!($c.nrows(), a.nrows(), "C.nrows() != A.nrows()"); assert_eq!($c.nrows(), a.nrows(), "C.nrows() != A.nrows()");
assert_eq!($c.ncols(), a.ncols(), "C.ncols() != A.ncols()"); assert_eq!($c.ncols(), a.ncols(), "C.ncols() != A.ncols()");
}, }
Op::Transpose(a) => { Op::Transpose(a) => {
assert_eq!($c.nrows(), a.ncols(), "C.nrows() != A.ncols()"); assert_eq!($c.nrows(), a.ncols(), "C.nrows() != A.ncols()");
assert_eq!($c.ncols(), a.nrows(), "C.ncols() != A.nrows()"); assert_eq!($c.ncols(), a.nrows(), "C.ncols() != A.nrows()");
} }
} }
};
}
} }
mod cs;
mod csc; mod csc;
mod csr; mod csr;
mod pattern; mod pattern;
mod cs;
pub use csc::*; pub use csc::*;
pub use csr::*; pub use csr::*;
pub use pattern::*; pub use pattern::*;
use std::fmt::Formatter;
use std::fmt; use std::fmt;
use std::fmt::Formatter;
/// A description of the error that occurred during an arithmetic operation. /// A description of the error that occurred during an arithmetic operation.
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct OperationError { pub struct OperationError {
error_kind: OperationErrorKind, error_kind: OperationErrorKind,
message: String message: String,
} }
/// The different kinds of operation errors that may occur. /// The different kinds of operation errors that may occur.
@ -92,7 +89,10 @@ pub enum OperationErrorKind {
impl OperationError { impl OperationError {
fn from_kind_and_message(error_type: OperationErrorKind, message: String) -> Self { fn from_kind_and_message(error_type: OperationErrorKind, message: String) -> Self {
Self { error_kind: error_type, message } Self {
error_kind: error_type,
message,
}
} }
/// The operation error kind. /// The operation error kind.
@ -110,8 +110,12 @@ impl fmt::Display for OperationError {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
write!(f, "Sparse matrix operation error: ")?; write!(f, "Sparse matrix operation error: ")?;
match self.kind() { match self.kind() {
OperationErrorKind::InvalidPattern => { write!(f, "InvalidPattern")?; } OperationErrorKind::InvalidPattern => {
OperationErrorKind::Singular => { write!(f, "Singular")?; } write!(f, "InvalidPattern")?;
}
OperationErrorKind::Singular => {
write!(f, "Singular")?;
}
} }
write!(f, " Message: {}", self.message) write!(f, " Message: {}", self.message)
} }

View File

@ -12,11 +12,17 @@ use std::iter;
/// # Panics /// # Panics
/// ///
/// Panics if the patterns do not have the same major and minor dimensions. /// Panics if the patterns do not have the same major and minor dimensions.
pub fn spadd_pattern(a: &SparsityPattern, pub fn spadd_pattern(a: &SparsityPattern, b: &SparsityPattern) -> SparsityPattern {
b: &SparsityPattern) -> SparsityPattern assert_eq!(
{ a.major_dim(),
assert_eq!(a.major_dim(), b.major_dim(), "Patterns must have identical major dimensions."); b.major_dim(),
assert_eq!(a.minor_dim(), b.minor_dim(), "Patterns must have identical minor dimensions."); "Patterns must have identical major dimensions."
);
assert_eq!(
a.minor_dim(),
b.minor_dim(),
"Patterns must have identical minor dimensions."
);
let mut offsets = Vec::new(); let mut offsets = Vec::new();
let mut indices = Vec::new(); let mut indices = Vec::new();
@ -33,8 +39,7 @@ pub fn spadd_pattern(a: &SparsityPattern,
} }
// TODO: Consider circumventing format checks? (requires unsafe, should benchmark first) // TODO: Consider circumventing format checks? (requires unsafe, should benchmark first)
SparsityPattern::try_from_offsets_and_indices( SparsityPattern::try_from_offsets_and_indices(a.major_dim(), a.minor_dim(), offsets, indices)
a.major_dim(), a.minor_dim(), offsets, indices)
.expect("Internal error: Pattern must be valid by definition") .expect("Internal error: Pattern must be valid by definition")
} }
@ -66,7 +71,11 @@ pub fn spmm_csc_pattern(a: &SparsityPattern, b: &SparsityPattern) -> SparsityPat
/// Panics if the patterns, when interpreted as CSR patterns, are not compatible for /// Panics if the patterns, when interpreted as CSR patterns, are not compatible for
/// matrix multiplication. /// matrix multiplication.
pub fn spmm_csr_pattern(a: &SparsityPattern, b: &SparsityPattern) -> SparsityPattern { pub fn spmm_csr_pattern(a: &SparsityPattern, b: &SparsityPattern) -> SparsityPattern {
assert_eq!(a.minor_dim(), b.major_dim(), "a and b must have compatible dimensions"); assert_eq!(
a.minor_dim(),
b.major_dim(),
"a and b must have compatible dimensions"
);
let mut offsets = Vec::new(); let mut offsets = Vec::new();
let mut indices = Vec::new(); let mut indices = Vec::new();
@ -110,8 +119,10 @@ pub fn spmm_csr_pattern(a: &SparsityPattern, b: &SparsityPattern) -> SparsityPat
/// Iterate over the union of the two sets represented by sorted slices /// Iterate over the union of the two sets represented by sorted slices
/// (with unique elements) /// (with unique elements)
fn iterate_union<'a>(mut sorted_a: &'a [usize], fn iterate_union<'a>(
mut sorted_b: &'a [usize]) -> impl Iterator<Item=usize> + 'a { mut sorted_a: &'a [usize],
mut sorted_b: &'a [usize],
) -> impl Iterator<Item = usize> + 'a {
iter::from_fn(move || { iter::from_fn(move || {
if let (Some(a_item), Some(b_item)) = (sorted_a.first(), sorted_b.first()) { if let (Some(a_item), Some(b_item)) = (sorted_a.first(), sorted_b.first()) {
let item = if a_item < b_item { let item = if a_item < b_item {

View File

@ -1,8 +1,8 @@
//! Sparsity patterns for CSR and CSC matrices. //! Sparsity patterns for CSR and CSC matrices.
use crate::SparseFormatError;
use std::fmt;
use std::error::Error;
use crate::cs::transpose_cs; use crate::cs::transpose_cs;
use crate::SparseFormatError;
use std::error::Error;
use std::fmt;
/// A representation of the sparsity pattern of a CSR or CSC matrix. /// A representation of the sparsity pattern of a CSR or CSC matrix.
/// ///
@ -236,12 +236,15 @@ impl SparsityPattern {
self.minor_dim(), self.minor_dim(),
self.major_offsets(), self.major_offsets(),
self.minor_indices(), self.minor_indices(),
&values); &values,
);
// TODO: Skip checks // TODO: Skip checks
Self::try_from_offsets_and_indices(self.minor_dim(), Self::try_from_offsets_and_indices(
self.minor_dim(),
self.major_dim(), self.major_dim(),
new_offsets, new_offsets,
new_indices) new_indices,
)
.expect("Internal error: Transpose should never fail.") .expect("Internal error: Transpose should never fail.")
} }
} }
@ -275,22 +278,25 @@ pub enum SparsityPatternFormatError {
impl From<SparsityPatternFormatError> for SparseFormatError { impl From<SparsityPatternFormatError> for SparseFormatError {
fn from(err: SparsityPatternFormatError) -> Self { fn from(err: SparsityPatternFormatError) -> Self {
use SparsityPatternFormatError::*;
use SparsityPatternFormatError::DuplicateEntry as PatternDuplicateEntry;
use crate::SparseFormatErrorKind; use crate::SparseFormatErrorKind;
use crate::SparseFormatErrorKind::*; use crate::SparseFormatErrorKind::*;
use SparsityPatternFormatError::DuplicateEntry as PatternDuplicateEntry;
use SparsityPatternFormatError::*;
match err { match err {
InvalidOffsetArrayLength InvalidOffsetArrayLength
| InvalidOffsetFirstLast | InvalidOffsetFirstLast
| NonmonotonicOffsets | NonmonotonicOffsets
| NonmonotonicMinorIndices | NonmonotonicMinorIndices => {
=> SparseFormatError::from_kind_and_error(InvalidStructure, Box::from(err)), SparseFormatError::from_kind_and_error(InvalidStructure, Box::from(err))
MinorIndexOutOfBounds }
=> SparseFormatError::from_kind_and_error(IndexOutOfBounds, MinorIndexOutOfBounds => {
Box::from(err)), SparseFormatError::from_kind_and_error(IndexOutOfBounds, Box::from(err))
PatternDuplicateEntry }
=> SparseFormatError::from_kind_and_error(SparseFormatErrorKind::DuplicateEntry, PatternDuplicateEntry => SparseFormatError::from_kind_and_error(
Box::from(err)), #[allow(unused_qualifications)]
SparseFormatErrorKind::DuplicateEntry,
Box::from(err),
),
} }
} }
} }
@ -300,22 +306,25 @@ impl fmt::Display for SparsityPatternFormatError {
match self { match self {
SparsityPatternFormatError::InvalidOffsetArrayLength => { SparsityPatternFormatError::InvalidOffsetArrayLength => {
write!(f, "Length of offset array is not equal to (major_dim + 1).") write!(f, "Length of offset array is not equal to (major_dim + 1).")
}, }
SparsityPatternFormatError::InvalidOffsetFirstLast => { SparsityPatternFormatError::InvalidOffsetFirstLast => {
write!(f, "First or last offset is incompatible with format.") write!(f, "First or last offset is incompatible with format.")
}, }
SparsityPatternFormatError::NonmonotonicOffsets => { SparsityPatternFormatError::NonmonotonicOffsets => {
write!(f, "Offsets are not monotonically increasing.") write!(f, "Offsets are not monotonically increasing.")
}, }
SparsityPatternFormatError::MinorIndexOutOfBounds => { SparsityPatternFormatError::MinorIndexOutOfBounds => {
write!(f, "A minor index is out of bounds.") write!(f, "A minor index is out of bounds.")
}, }
SparsityPatternFormatError::DuplicateEntry => { SparsityPatternFormatError::DuplicateEntry => {
write!(f, "Input data contains duplicate entries.") write!(f, "Input data contains duplicate entries.")
}, }
SparsityPatternFormatError::NonmonotonicMinorIndices => { SparsityPatternFormatError::NonmonotonicMinorIndices => {
write!(f, "Minor indices are not monotonically increasing within each lane.") write!(
}, f,
"Minor indices are not monotonically increasing within each lane."
)
}
} }
} }
} }
@ -340,7 +349,7 @@ impl<'a> SparsityPatternIter<'a> {
major_offsets: pattern.major_offsets(), major_offsets: pattern.major_offsets(),
minor_indices: pattern.minor_indices(), minor_indices: pattern.minor_indices(),
current_lane_idx: 0, current_lane_idx: 0,
remaining_minors_in_lane: minors_in_first_lane remaining_minors_in_lane: minors_in_first_lane,
} }
} }
} }
@ -375,11 +384,10 @@ impl<'a> Iterator for SparsityPatternIter<'a> {
let upper = self.major_offsets[self.current_lane_idx + 1]; let upper = self.major_offsets[self.current_lane_idx + 1];
if upper > lower { if upper > lower {
self.remaining_minors_in_lane = &self.minor_indices[(lower + 1)..upper]; self.remaining_minors_in_lane = &self.minor_indices[(lower + 1)..upper];
return Some((self.current_lane_idx, self.minor_indices[lower])) return Some((self.current_lane_idx, self.minor_indices[lower]));
} }
} }
} }
} }
} }
} }

View File

@ -11,20 +11,22 @@
mod proptest_patched; mod proptest_patched;
use crate::coo::CooMatrix; use crate::coo::CooMatrix;
use proptest::prelude::*; use crate::csc::CscMatrix;
use proptest::collection::{vec, hash_map, btree_set};
use nalgebra::{Scalar, Dim};
use std::cmp::min;
use std::iter::{repeat};
use proptest::sample::{Index};
use crate::csr::CsrMatrix; use crate::csr::CsrMatrix;
use crate::pattern::SparsityPattern; use crate::pattern::SparsityPattern;
use crate::csc::CscMatrix;
use nalgebra::proptest::DimRange; use nalgebra::proptest::DimRange;
use nalgebra::{Dim, Scalar};
use proptest::collection::{btree_set, hash_map, vec};
use proptest::prelude::*;
use proptest::sample::Index;
use std::cmp::min;
use std::iter::repeat;
fn dense_row_major_coord_strategy(nrows: usize, ncols: usize, nnz: usize) fn dense_row_major_coord_strategy(
-> impl Strategy<Value=Vec<(usize, usize)>> nrows: usize,
{ ncols: usize,
nnz: usize,
) -> impl Strategy<Value = Vec<(usize, usize)>> {
assert!(nnz <= nrows * ncols); assert!(nnz <= nrows * ncols);
let mut booleans = vec![true; nnz]; let mut booleans = vec![true; nnz];
booleans.append(&mut vec![false; (nrows * ncols) - nnz]); booleans.append(&mut vec![false; (nrows * ncols) - nnz]);
@ -38,8 +40,7 @@ fn dense_row_major_coord_strategy(nrows: usize, ncols: usize, nnz: usize)
// // Need to shuffle to make sure they are randomly distributed // // Need to shuffle to make sure they are randomly distributed
// .prop_shuffle() // .prop_shuffle()
proptest_patched::Shuffle(Just(booleans)) proptest_patched::Shuffle(Just(booleans)).prop_map(move |booleans| {
.prop_map(move |booleans| {
booleans booleans
.into_iter() .into_iter()
.enumerate() .enumerate()
@ -60,11 +61,12 @@ fn dense_row_major_coord_strategy(nrows: usize, ncols: usize, nnz: usize)
/// A strategy for generating `nnz` triplets. /// A strategy for generating `nnz` triplets.
/// ///
/// This strategy should generally only be used when `nnz` is close to `nrows * ncols`. /// This strategy should generally only be used when `nnz` is close to `nrows * ncols`.
fn dense_triplet_strategy<T>(value_strategy: T, fn dense_triplet_strategy<T>(
value_strategy: T,
nrows: usize, nrows: usize,
ncols: usize, ncols: usize,
nnz: usize) nnz: usize,
-> impl Strategy<Value=Vec<(usize, usize, T::Value)>> ) -> impl Strategy<Value = Vec<(usize, usize, T::Value)>>
where where
T: Strategy + Clone + 'static, T: Strategy + Clone + 'static,
T::Value: Scalar, T::Value: Scalar,
@ -100,13 +102,12 @@ where
}) })
// Assign values to each coordinate pair in order to generate a list of triplets // Assign values to each coordinate pair in order to generate a list of triplets
.prop_flat_map(move |coords| { .prop_flat_map(move |coords| {
vec![value_strategy.clone(); coords.len()] vec![value_strategy.clone(); coords.len()].prop_map(move |values| {
.prop_map(move |values| { coords
coords.clone().into_iter() .clone()
.into_iter()
.zip(values) .zip(values)
.map(|((i, j), v)| { .map(|((i, j), v)| (i, j, v))
(i, j, v)
})
.collect::<Vec<_>>() .collect::<Vec<_>>()
}) })
}) })
@ -116,11 +117,12 @@ where
/// ///
/// This strategy should generally only be used when `nnz << nrows * ncols`. If `nnz` is too /// This strategy should generally only be used when `nnz << nrows * ncols`. If `nnz` is too
/// close to `nrows * ncols` it may fail due to excessive rejected samples. /// close to `nrows * ncols` it may fail due to excessive rejected samples.
fn sparse_triplet_strategy<T>(value_strategy: T, fn sparse_triplet_strategy<T>(
value_strategy: T,
nrows: usize, nrows: usize,
ncols: usize, ncols: usize,
nnz: usize) nnz: usize,
-> impl Strategy<Value=Vec<(usize, usize, T::Value)>> ) -> impl Strategy<Value = Vec<(usize, usize, T::Value)>>
where where
T: Strategy + Clone + 'static, T: Strategy + Clone + 'static,
T::Value: Scalar, T::Value: Scalar,
@ -131,10 +133,7 @@ fn sparse_triplet_strategy<T>(value_strategy: T,
let coord_strategy = (row_index_strategy, col_index_strategy); let coord_strategy = (row_index_strategy, col_index_strategy);
hash_map(coord_strategy, value_strategy.clone(), nnz) hash_map(coord_strategy, value_strategy.clone(), nnz)
.prop_map(|hash_map| { .prop_map(|hash_map| {
let triplets: Vec<_> = hash_map let triplets: Vec<_> = hash_map.into_iter().map(|((i, j), v)| (i, j, v)).collect();
.into_iter()
.map(|((i, j), v)| (i, j, v))
.collect();
triplets triplets
}) })
// Although order in the hash map is unspecified, it's not necessarily *random* // Although order in the hash map is unspecified, it's not necessarily *random*
@ -153,18 +152,23 @@ pub fn coo_no_duplicates<T>(
value_strategy: T, value_strategy: T,
rows: impl Into<DimRange>, rows: impl Into<DimRange>,
cols: impl Into<DimRange>, cols: impl Into<DimRange>,
max_nonzeros: usize) -> impl Strategy<Value=CooMatrix<T::Value>> max_nonzeros: usize,
) -> impl Strategy<Value = CooMatrix<T::Value>>
where where
T: Strategy + Clone + 'static, T: Strategy + Clone + 'static,
T::Value: Scalar, T::Value: Scalar,
{ {
(rows.into().to_range_inclusive(), cols.into().to_range_inclusive()) (
rows.into().to_range_inclusive(),
cols.into().to_range_inclusive(),
)
.prop_flat_map(move |(nrows, ncols)| { .prop_flat_map(move |(nrows, ncols)| {
let max_nonzeros = min(max_nonzeros, nrows * ncols); let max_nonzeros = min(max_nonzeros, nrows * ncols);
let size_range = 0..=max_nonzeros; let size_range = 0..=max_nonzeros;
let value_strategy = value_strategy.clone(); let value_strategy = value_strategy.clone();
size_range.prop_flat_map(move |nnz| { size_range
.prop_flat_map(move |nnz| {
let value_strategy = value_strategy.clone(); let value_strategy = value_strategy.clone();
if nnz as f64 > 0.10 * (nrows as f64) * (ncols as f64) { if nnz as f64 > 0.10 * (nrows as f64) * (ncols as f64) {
// If the number of nnz is sufficiently dense, then use the dense // If the number of nnz is sufficiently dense, then use the dense
@ -202,8 +206,8 @@ pub fn coo_with_duplicates<T>(
rows: impl Into<DimRange>, rows: impl Into<DimRange>,
cols: impl Into<DimRange>, cols: impl Into<DimRange>,
max_nonzeros: usize, max_nonzeros: usize,
max_duplicates: usize) max_duplicates: usize,
-> impl Strategy<Value=CooMatrix<T::Value>> ) -> impl Strategy<Value = CooMatrix<T::Value>>
where where
T: Strategy + Clone + 'static, T: Strategy + Clone + 'static,
T::Value: Scalar, T::Value: Scalar,
@ -212,7 +216,8 @@ where
let duplicate_strategy = vec((any::<Index>(), value_strategy.clone()), 0..=max_duplicates); let duplicate_strategy = vec((any::<Index>(), value_strategy.clone()), 0..=max_duplicates);
(coo_strategy, duplicate_strategy) (coo_strategy, duplicate_strategy)
.prop_flat_map(|(coo, duplicates)| { .prop_flat_map(|(coo, duplicates)| {
let mut triplets: Vec<(usize, usize, T::Value)> = coo.triplet_iter() let mut triplets: Vec<(usize, usize, T::Value)> = coo
.triplet_iter()
.map(|(i, j, v)| (i, j, v.clone())) .map(|(i, j, v)| (i, j, v.clone()))
.collect(); .collect();
if !triplets.is_empty() { if !triplets.is_empty() {
@ -238,7 +243,11 @@ where
}) })
} }
fn sparsity_pattern_from_row_major_coords<I>(nmajor: usize, nminor: usize, coords: I) -> SparsityPattern fn sparsity_pattern_from_row_major_coords<I>(
nmajor: usize,
nminor: usize,
coords: I,
) -> SparsityPattern
where where
I: Iterator<Item = (usize, usize)> + ExactSizeIterator, I: Iterator<Item = (usize, usize)> + ExactSizeIterator,
{ {
@ -248,7 +257,10 @@ where
offsets.push(0); offsets.push(0);
for (idx, (i, j)) in coords.enumerate() { for (idx, (i, j)) in coords.enumerate() {
assert!(i >= current_major); assert!(i >= current_major);
assert!(i < nmajor && j < nminor, "Generated coords are out of bounds"); assert!(
i < nmajor && j < nminor,
"Generated coords are out of bounds"
);
while current_major < i { while current_major < i {
offsets.push(idx); offsets.push(idx);
current_major += 1; current_major += 1;
@ -264,10 +276,7 @@ where
assert_eq!(offsets.first().unwrap(), &0); assert_eq!(offsets.first().unwrap(), &0);
assert_eq!(offsets.len(), nmajor + 1); assert_eq!(offsets.len(), nmajor + 1);
SparsityPattern::try_from_offsets_and_indices(nmajor, SparsityPattern::try_from_offsets_and_indices(nmajor, nminor, offsets, minors)
nminor,
offsets,
minors)
.expect("Internal error: Generated sparsity pattern is invalid") .expect("Internal error: Generated sparsity pattern is invalid")
} }
@ -275,14 +284,17 @@ where
pub fn sparsity_pattern( pub fn sparsity_pattern(
major_lanes: impl Into<DimRange>, major_lanes: impl Into<DimRange>,
minor_lanes: impl Into<DimRange>, minor_lanes: impl Into<DimRange>,
max_nonzeros: usize) max_nonzeros: usize,
-> impl Strategy<Value=SparsityPattern> ) -> impl Strategy<Value = SparsityPattern> {
{ (
(major_lanes.into().to_range_inclusive(), minor_lanes.into().to_range_inclusive()) major_lanes.into().to_range_inclusive(),
minor_lanes.into().to_range_inclusive(),
)
.prop_flat_map(move |(nmajor, nminor)| { .prop_flat_map(move |(nmajor, nminor)| {
let max_nonzeros = min(nmajor * nminor, max_nonzeros); let max_nonzeros = min(nmajor * nminor, max_nonzeros);
(Just(nmajor), Just(nminor), 0..=max_nonzeros) (Just(nmajor), Just(nminor), 0..=max_nonzeros)
}).prop_flat_map(move |(nmajor, nminor, nnz)| { })
.prop_flat_map(move |(nmajor, nminor, nnz)| {
if 10 * nnz < nmajor * nminor { if 10 * nnz < nmajor * nminor {
// If nnz is small compared to a dense matrix, then use a sparse sampling strategy // If nnz is small compared to a dense matrix, then use a sparse sampling strategy
btree_set((0..nmajor, 0..nminor), nnz) btree_set((0..nmajor, 0..nminor), nnz)
@ -297,24 +309,30 @@ pub fn sparsity_pattern(
.prop_map(move |coords| { .prop_map(move |coords| {
let coords = coords.into_iter(); let coords = coords.into_iter();
sparsity_pattern_from_row_major_coords(nmajor, nminor, coords) sparsity_pattern_from_row_major_coords(nmajor, nminor, coords)
}).boxed() })
.boxed()
} }
}) })
} }
/// A strategy for generating CSR matrices. /// A strategy for generating CSR matrices.
pub fn csr<T>(value_strategy: T, pub fn csr<T>(
value_strategy: T,
rows: impl Into<DimRange>, rows: impl Into<DimRange>,
cols: impl Into<DimRange>, cols: impl Into<DimRange>,
max_nonzeros: usize) max_nonzeros: usize,
-> impl Strategy<Value=CsrMatrix<T::Value>> ) -> impl Strategy<Value = CsrMatrix<T::Value>>
where where
T: Strategy + Clone + 'static, T: Strategy + Clone + 'static,
T::Value: Scalar, T::Value: Scalar,
{ {
let rows = rows.into(); let rows = rows.into();
let cols = cols.into(); let cols = cols.into();
sparsity_pattern(rows.lower_bound().value() ..= rows.upper_bound().value(), cols.lower_bound().value() ..= cols.upper_bound().value(), max_nonzeros) sparsity_pattern(
rows.lower_bound().value()..=rows.upper_bound().value(),
cols.lower_bound().value()..=cols.upper_bound().value(),
max_nonzeros,
)
.prop_flat_map(move |pattern| { .prop_flat_map(move |pattern| {
let nnz = pattern.nnz(); let nnz = pattern.nnz();
let values = vec![value_strategy.clone(); nnz]; let values = vec![value_strategy.clone(); nnz];
@ -327,18 +345,23 @@ where
} }
/// A strategy for generating CSC matrices. /// A strategy for generating CSC matrices.
pub fn csc<T>(value_strategy: T, pub fn csc<T>(
value_strategy: T,
rows: impl Into<DimRange>, rows: impl Into<DimRange>,
cols: impl Into<DimRange>, cols: impl Into<DimRange>,
max_nonzeros: usize) max_nonzeros: usize,
-> impl Strategy<Value=CscMatrix<T::Value>> ) -> impl Strategy<Value = CscMatrix<T::Value>>
where where
T: Strategy + Clone + 'static, T: Strategy + Clone + 'static,
T::Value: Scalar, T::Value: Scalar,
{ {
let rows = rows.into(); let rows = rows.into();
let cols = cols.into(); let cols = cols.into();
sparsity_pattern(cols.lower_bound().value() ..= cols.upper_bound().value(), rows.lower_bound().value() ..= rows.upper_bound().value(), max_nonzeros) sparsity_pattern(
cols.lower_bound().value()..=cols.upper_bound().value(),
rows.lower_bound().value()..=rows.upper_bound().value(),
max_nonzeros,
)
.prop_flat_map(move |pattern| { .prop_flat_map(move |pattern| {
let nnz = pattern.nnz(); let nnz = pattern.nnz();
let values = vec![value_strategy.clone(); nnz]; let values = vec![value_strategy.clone(); nnz];

View File

@ -22,11 +22,11 @@
*/ */
use proptest::strategy::{Strategy, Shuffleable, NewTree, ValueTree};
use proptest::test_runner::{TestRunner, TestRng};
use std::cell::Cell;
use proptest::num; use proptest::num;
use proptest::prelude::Rng; use proptest::prelude::Rng;
use proptest::strategy::{NewTree, Shuffleable, Strategy, ValueTree};
use proptest::test_runner::{TestRng, TestRunner};
use std::cell::Cell;
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
#[must_use = "strategies do nothing unless used"] #[must_use = "strategies do nothing unless used"]

View File

@ -1,15 +1,15 @@
use proptest::strategy::Strategy;
use nalgebra_sparse::csr::CsrMatrix;
use nalgebra_sparse::proptest::{csr, csc};
use nalgebra_sparse::csc::CscMatrix; use nalgebra_sparse::csc::CscMatrix;
use std::ops::RangeInclusive; use nalgebra_sparse::csr::CsrMatrix;
use std::convert::{TryFrom}; use nalgebra_sparse::proptest::{csc, csr};
use proptest::strategy::Strategy;
use std::convert::TryFrom;
use std::fmt::Debug; use std::fmt::Debug;
use std::ops::RangeInclusive;
#[macro_export] #[macro_export]
macro_rules! assert_panics { macro_rules! assert_panics {
($e:expr) => {{ ($e:expr) => {{
use std::panic::{catch_unwind}; use std::panic::catch_unwind;
use std::stringify; use std::stringify;
let expr_string = stringify!($e); let expr_string = stringify!($e);
@ -22,7 +22,10 @@ macro_rules! assert_panics {
let result = catch_unwind(|| $e); let result = catch_unwind(|| $e);
if result.is_ok() { if result.is_ok() {
panic!("assert_panics!({}) failed: the expression did not panic.", expr_string); panic!(
"assert_panics!({}) failed: the expression did not panic.",
expr_string
);
} }
}}; }};
} }
@ -34,14 +37,20 @@ pub const PROPTEST_I32_VALUE_STRATEGY: RangeInclusive<i32> = -5 ..= 5;
pub fn value_strategy<T>() -> RangeInclusive<T> pub fn value_strategy<T>() -> RangeInclusive<T>
where where
T: TryFrom<i32>, T: TryFrom<i32>,
T::Error: Debug T::Error: Debug,
{ {
let (start, end) = (PROPTEST_I32_VALUE_STRATEGY.start(), PROPTEST_I32_VALUE_STRATEGY.end()); let (start, end) = (
PROPTEST_I32_VALUE_STRATEGY.start(),
PROPTEST_I32_VALUE_STRATEGY.end(),
);
T::try_from(*start).unwrap()..=T::try_from(*end).unwrap() T::try_from(*start).unwrap()..=T::try_from(*end).unwrap()
} }
pub fn non_zero_i32_value_strategy() -> impl Strategy<Value = i32> { pub fn non_zero_i32_value_strategy() -> impl Strategy<Value = i32> {
let (start, end) = (PROPTEST_I32_VALUE_STRATEGY.start(), PROPTEST_I32_VALUE_STRATEGY.end()); let (start, end) = (
PROPTEST_I32_VALUE_STRATEGY.start(),
PROPTEST_I32_VALUE_STRATEGY.end(),
);
assert!(start < &0); assert!(start < &0);
assert!(end > &0); assert!(end > &0);
// Note: we don't use RangeInclusive for the second range, because then we'd have different // Note: we don't use RangeInclusive for the second range, because then we'd have different
@ -50,9 +59,19 @@ pub fn non_zero_i32_value_strategy() -> impl Strategy<Value=i32> {
} }
pub fn csr_strategy() -> impl Strategy<Value = CsrMatrix<i32>> { pub fn csr_strategy() -> impl Strategy<Value = CsrMatrix<i32>> {
csr(PROPTEST_I32_VALUE_STRATEGY, PROPTEST_MATRIX_DIM, PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ) csr(
PROPTEST_I32_VALUE_STRATEGY,
PROPTEST_MATRIX_DIM,
PROPTEST_MATRIX_DIM,
PROPTEST_MAX_NNZ,
)
} }
pub fn csc_strategy() -> impl Strategy<Value = CscMatrix<i32>> { pub fn csc_strategy() -> impl Strategy<Value = CscMatrix<i32>> {
csc(PROPTEST_I32_VALUE_STRATEGY, PROPTEST_MATRIX_DIM, PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ) csc(
PROPTEST_I32_VALUE_STRATEGY,
PROPTEST_MATRIX_DIM,
PROPTEST_MATRIX_DIM,
PROPTEST_MAX_NNZ,
)
} }

View File

@ -1,17 +1,16 @@
use nalgebra_sparse::coo::CooMatrix;
use nalgebra_sparse::convert::serial::{convert_coo_dense, convert_coo_csr,
convert_dense_coo, convert_csr_dense,
convert_csr_coo, convert_dense_csr,
convert_csc_coo, convert_coo_csc,
convert_csc_dense, convert_dense_csc,
convert_csr_csc, convert_csc_csr};
use nalgebra_sparse::proptest::{coo_with_duplicates, coo_no_duplicates, csr, csc};
use nalgebra::proptest::matrix;
use proptest::prelude::*;
use nalgebra::DMatrix;
use nalgebra_sparse::csr::CsrMatrix;
use nalgebra_sparse::csc::CscMatrix;
use crate::common::csc_strategy; use crate::common::csc_strategy;
use nalgebra::proptest::matrix;
use nalgebra::DMatrix;
use nalgebra_sparse::convert::serial::{
convert_coo_csc, convert_coo_csr, convert_coo_dense, convert_csc_coo, convert_csc_csr,
convert_csc_dense, convert_csr_coo, convert_csr_csc, convert_csr_dense, convert_dense_coo,
convert_dense_csc, convert_dense_csr,
};
use nalgebra_sparse::coo::CooMatrix;
use nalgebra_sparse::csc::CscMatrix;
use nalgebra_sparse::csr::CsrMatrix;
use nalgebra_sparse::proptest::{coo_no_duplicates, coo_with_duplicates, csc, csr};
use proptest::prelude::*;
#[test] #[test]
fn test_convert_dense_coo() { fn test_convert_dense_coo() {
@ -41,15 +40,16 @@ fn test_convert_dense_coo() {
// Here we implicitly test that the coo matrix is indeed constructed from column-major // Here we implicitly test that the coo matrix is indeed constructed from column-major
// iteration of the dense matrix. // iteration of the dense matrix.
let dense = DMatrix::from_row_slice(2, 3, entries); let dense = DMatrix::from_row_slice(2, 3, entries);
let coo_no_dup = CooMatrix::try_from_triplets(2, 3, let coo_no_dup =
vec![0, 1, 0], CooMatrix::try_from_triplets(2, 3, vec![0, 1, 0], vec![0, 1, 2], vec![1, 5, 3])
vec![0, 1, 2],
vec![1, 5, 3])
.unwrap(); .unwrap();
let coo_dup = CooMatrix::try_from_triplets(2, 3, let coo_dup = CooMatrix::try_from_triplets(
2,
3,
vec![0, 1, 0, 1], vec![0, 1, 0, 1],
vec![0, 1, 2, 1], vec![0, 1, 2, 1],
vec![1, -2, 3, 7]) vec![1, -2, 3, 7],
)
.unwrap(); .unwrap();
assert_eq!(CooMatrix::from(&dense), coo_no_dup); assert_eq!(CooMatrix::from(&dense), coo_no_dup);
@ -76,8 +76,9 @@ fn test_convert_coo_csr() {
4, 4,
vec![0, 1, 2, 5], vec![0, 1, 2, 5],
vec![1, 3, 0, 2, 3], vec![1, 3, 0, 2, 3],
vec![2, 4, 1, 1, 2] vec![2, 4, 1, 1, 2],
).unwrap(); )
.unwrap();
assert_eq!(convert_coo_csr(&coo), expected_csr); assert_eq!(convert_coo_csr(&coo), expected_csr);
} }
@ -101,8 +102,9 @@ fn test_convert_coo_csr() {
4, 4,
vec![0, 1, 2, 5], vec![0, 1, 2, 5],
vec![1, 3, 0, 2, 3], vec![1, 3, 0, 2, 3],
vec![5, 4, 1, 1, 4] vec![5, 4, 1, 1, 4],
).unwrap(); )
.unwrap();
assert_eq!(convert_coo_csr(&coo), expected_csr); assert_eq!(convert_coo_csr(&coo), expected_csr);
} }
@ -115,16 +117,18 @@ fn test_convert_csr_coo() {
4, 4,
vec![0, 1, 2, 5], vec![0, 1, 2, 5],
vec![1, 3, 0, 2, 3], vec![1, 3, 0, 2, 3],
vec![5, 4, 1, 1, 4] vec![5, 4, 1, 1, 4],
).unwrap(); )
.unwrap();
let expected_coo = CooMatrix::try_from_triplets( let expected_coo = CooMatrix::try_from_triplets(
3, 3,
4, 4,
vec![0, 1, 2, 2, 2], vec![0, 1, 2, 2, 2],
vec![1, 3, 0, 2, 3], vec![1, 3, 0, 2, 3],
vec![5, 4, 1, 1, 4] vec![5, 4, 1, 1, 4],
).unwrap(); )
.unwrap();
assert_eq!(convert_csr_coo(&csr), expected_coo); assert_eq!(convert_csr_coo(&csr), expected_coo);
} }
@ -148,8 +152,9 @@ fn test_convert_coo_csc() {
4, 4,
vec![0, 1, 2, 3, 5], vec![0, 1, 2, 3, 5],
vec![2, 0, 2, 1, 2], vec![2, 0, 2, 1, 2],
vec![1, 2, 1, 4, 2] vec![1, 2, 1, 4, 2],
).unwrap(); )
.unwrap();
assert_eq!(convert_coo_csc(&coo), expected_csc); assert_eq!(convert_coo_csc(&coo), expected_csc);
} }
@ -173,8 +178,9 @@ fn test_convert_coo_csc() {
4, 4,
vec![0, 1, 2, 3, 5], vec![0, 1, 2, 3, 5],
vec![2, 0, 2, 1, 2], vec![2, 0, 2, 1, 2],
vec![1, 5, 1, 4, 4] vec![1, 5, 1, 4, 4],
).unwrap(); )
.unwrap();
assert_eq!(convert_coo_csc(&coo), expected_csc); assert_eq!(convert_coo_csc(&coo), expected_csc);
} }
@ -187,16 +193,18 @@ fn test_convert_csc_coo() {
4, 4,
vec![0, 1, 2, 3, 5], vec![0, 1, 2, 3, 5],
vec![2, 0, 2, 1, 2], vec![2, 0, 2, 1, 2],
vec![1, 2, 1, 4, 2] vec![1, 2, 1, 4, 2],
).unwrap(); )
.unwrap();
let expected_coo = CooMatrix::try_from_triplets( let expected_coo = CooMatrix::try_from_triplets(
3, 3,
4, 4,
vec![2, 0, 2, 1, 2], vec![2, 0, 2, 1, 2],
vec![0, 1, 2, 3, 3], vec![0, 1, 2, 3, 3],
vec![1, 2, 1, 4, 2] vec![1, 2, 1, 4, 2],
).unwrap(); )
.unwrap();
assert_eq!(convert_csc_coo(&csc), expected_coo); assert_eq!(convert_csc_coo(&csc), expected_coo);
} }
@ -209,7 +217,8 @@ fn test_convert_csr_csc_bidirectional() {
vec![0, 3, 4, 6], vec![0, 3, 4, 6],
vec![1, 2, 3, 0, 1, 3], vec![1, 2, 3, 0, 1, 3],
vec![5, 3, 2, 2, 1, 4], vec![5, 3, 2, 2, 1, 4],
).unwrap(); )
.unwrap();
let csc = CscMatrix::try_from_csc_data( let csc = CscMatrix::try_from_csc_data(
3, 3,
@ -217,7 +226,8 @@ fn test_convert_csr_csc_bidirectional() {
vec![0, 1, 3, 4, 6], vec![0, 1, 3, 4, 6],
vec![1, 0, 2, 0, 0, 2], vec![1, 0, 2, 0, 0, 2],
vec![2, 5, 1, 3, 2, 4], vec![2, 5, 1, 3, 2, 4],
).unwrap(); )
.unwrap();
assert_eq!(convert_csr_csc(&csr), csc); assert_eq!(convert_csr_csc(&csr), csc);
assert_eq!(convert_csc_csr(&csc), csr); assert_eq!(convert_csc_csr(&csc), csr);
@ -231,7 +241,8 @@ fn test_convert_csr_dense_bidirectional() {
vec![0, 3, 4, 6], vec![0, 3, 4, 6],
vec![1, 2, 3, 0, 1, 3], vec![1, 2, 3, 0, 1, 3],
vec![5, 3, 2, 2, 1, 4], vec![5, 3, 2, 2, 1, 4],
).unwrap(); )
.unwrap();
#[rustfmt::skip] #[rustfmt::skip]
let dense = DMatrix::from_row_slice(3, 4, &[ let dense = DMatrix::from_row_slice(3, 4, &[
@ -252,7 +263,8 @@ fn test_convert_csc_dense_bidirectional() {
vec![0, 1, 3, 4, 6], vec![0, 1, 3, 4, 6],
vec![1, 0, 2, 0, 0, 2], vec![1, 0, 2, 0, 0, 2],
vec![2, 5, 1, 3, 2, 4], vec![2, 5, 1, 3, 2, 4],
).unwrap(); )
.unwrap();
#[rustfmt::skip] #[rustfmt::skip]
let dense = DMatrix::from_row_slice(3, 4, &[ let dense = DMatrix::from_row_slice(3, 4, &[

View File

@ -1,7 +1,7 @@
use nalgebra_sparse::{SparseFormatErrorKind};
use nalgebra_sparse::coo::CooMatrix;
use nalgebra::DMatrix;
use crate::assert_panics; use crate::assert_panics;
use nalgebra::DMatrix;
use nalgebra_sparse::coo::CooMatrix;
use nalgebra_sparse::SparseFormatErrorKind;
#[test] #[test]
fn coo_construction_for_valid_data() { fn coo_construction_for_valid_data() {
@ -10,8 +10,8 @@ fn coo_construction_for_valid_data() {
{ {
// Zero matrix // Zero matrix
let coo = CooMatrix::<i32>::try_from_triplets(3, 2, Vec::new(), Vec::new(), Vec::new()) let coo =
.unwrap(); CooMatrix::<i32>::try_from_triplets(3, 2, Vec::new(), Vec::new(), Vec::new()).unwrap();
assert_eq!(coo.nrows(), 3); assert_eq!(coo.nrows(), 3);
assert_eq!(coo.ncols(), 2); assert_eq!(coo.ncols(), 2);
assert!(coo.triplet_iter().next().is_none()); assert!(coo.triplet_iter().next().is_none());
@ -27,8 +27,8 @@ fn coo_construction_for_valid_data() {
let i = vec![0, 1, 0, 0, 2]; let i = vec![0, 1, 0, 0, 2];
let j = vec![0, 2, 1, 3, 3]; let j = vec![0, 2, 1, 3, 3];
let v = vec![2, 3, 7, 3, 1]; let v = vec![2, 3, 7, 3, 1];
let coo = CooMatrix::<i32>::try_from_triplets(3, 5, i.clone(), j.clone(), v.clone()) let coo =
.unwrap(); CooMatrix::<i32>::try_from_triplets(3, 5, i.clone(), j.clone(), v.clone()).unwrap();
assert_eq!(coo.nrows(), 3); assert_eq!(coo.nrows(), 3);
assert_eq!(coo.ncols(), 5); assert_eq!(coo.ncols(), 5);
@ -59,8 +59,8 @@ fn coo_construction_for_valid_data() {
let i = vec![0, 1, 0, 0, 0, 0, 2, 1]; let i = vec![0, 1, 0, 0, 0, 0, 2, 1];
let j = vec![0, 2, 0, 1, 0, 3, 3, 2]; let j = vec![0, 2, 0, 1, 0, 3, 3, 2];
let v = vec![2, 3, 4, 7, 1, 3, 1, 5]; let v = vec![2, 3, 4, 7, 1, 3, 1, 5];
let coo = CooMatrix::<i32>::try_from_triplets(3, 5, i.clone(), j.clone(), v.clone()) let coo =
.unwrap(); CooMatrix::<i32>::try_from_triplets(3, 5, i.clone(), j.clone(), v.clone()).unwrap();
assert_eq!(coo.nrows(), 3); assert_eq!(coo.nrows(), 3);
assert_eq!(coo.ncols(), 5); assert_eq!(coo.ncols(), 5);
@ -92,25 +92,37 @@ fn coo_try_from_triplets_reports_out_of_bounds_indices() {
{ {
// 0x0 matrix // 0x0 matrix
let result = CooMatrix::<i32>::try_from_triplets(0, 0, vec![0], vec![0], vec![2]); let result = CooMatrix::<i32>::try_from_triplets(0, 0, vec![0], vec![0], vec![2]);
assert!(matches!(result.unwrap_err().kind(), SparseFormatErrorKind::IndexOutOfBounds)); assert!(matches!(
result.unwrap_err().kind(),
SparseFormatErrorKind::IndexOutOfBounds
));
} }
{ {
// 1x1 matrix, row out of bounds // 1x1 matrix, row out of bounds
let result = CooMatrix::<i32>::try_from_triplets(1, 1, vec![1], vec![0], vec![2]); let result = CooMatrix::<i32>::try_from_triplets(1, 1, vec![1], vec![0], vec![2]);
assert!(matches!(result.unwrap_err().kind(), SparseFormatErrorKind::IndexOutOfBounds)); assert!(matches!(
result.unwrap_err().kind(),
SparseFormatErrorKind::IndexOutOfBounds
));
} }
{ {
// 1x1 matrix, col out of bounds // 1x1 matrix, col out of bounds
let result = CooMatrix::<i32>::try_from_triplets(1, 1, vec![0], vec![1], vec![2]); let result = CooMatrix::<i32>::try_from_triplets(1, 1, vec![0], vec![1], vec![2]);
assert!(matches!(result.unwrap_err().kind(), SparseFormatErrorKind::IndexOutOfBounds)); assert!(matches!(
result.unwrap_err().kind(),
SparseFormatErrorKind::IndexOutOfBounds
));
} }
{ {
// 1x1 matrix, row and col out of bounds // 1x1 matrix, row and col out of bounds
let result = CooMatrix::<i32>::try_from_triplets(1, 1, vec![1], vec![1], vec![2]); let result = CooMatrix::<i32>::try_from_triplets(1, 1, vec![1], vec![1], vec![2]);
assert!(matches!(result.unwrap_err().kind(), SparseFormatErrorKind::IndexOutOfBounds)); assert!(matches!(
result.unwrap_err().kind(),
SparseFormatErrorKind::IndexOutOfBounds
));
} }
{ {
@ -119,7 +131,10 @@ fn coo_try_from_triplets_reports_out_of_bounds_indices() {
let j = vec![0, 2, 1, 3, 3]; let j = vec![0, 2, 1, 3, 3];
let v = vec![2, 3, 7, 3, 1]; let v = vec![2, 3, 7, 3, 1];
let result = CooMatrix::<i32>::try_from_triplets(3, 5, i, j, v); let result = CooMatrix::<i32>::try_from_triplets(3, 5, i, j, v);
assert!(matches!(result.unwrap_err().kind(), SparseFormatErrorKind::IndexOutOfBounds)); assert!(matches!(
result.unwrap_err().kind(),
SparseFormatErrorKind::IndexOutOfBounds
));
} }
{ {
@ -128,7 +143,10 @@ fn coo_try_from_triplets_reports_out_of_bounds_indices() {
let j = vec![0, 2, 1, 5, 3]; let j = vec![0, 2, 1, 5, 3];
let v = vec![2, 3, 7, 3, 1]; let v = vec![2, 3, 7, 3, 1];
let result = CooMatrix::<i32>::try_from_triplets(3, 5, i, j, v); let result = CooMatrix::<i32>::try_from_triplets(3, 5, i, j, v);
assert!(matches!(result.unwrap_err().kind(), SparseFormatErrorKind::IndexOutOfBounds)); assert!(matches!(
result.unwrap_err().kind(),
SparseFormatErrorKind::IndexOutOfBounds
));
} }
} }
@ -137,16 +155,55 @@ fn coo_try_from_triplets_panics_on_mismatched_vectors() {
// Check that try_from_triplets panics when the triplet vectors have different lengths // Check that try_from_triplets panics when the triplet vectors have different lengths
macro_rules! assert_errs { macro_rules! assert_errs {
($result:expr) => { ($result:expr) => {
assert!(matches!($result.unwrap_err().kind(), SparseFormatErrorKind::InvalidStructure)) assert!(matches!(
} $result.unwrap_err().kind(),
SparseFormatErrorKind::InvalidStructure
))
};
} }
assert_errs!(CooMatrix::<i32>::try_from_triplets(3, 5, vec![1, 2], vec![0], vec![0])); assert_errs!(CooMatrix::<i32>::try_from_triplets(
assert_errs!(CooMatrix::<i32>::try_from_triplets(3, 5, vec![1], vec![0, 0], vec![0])); 3,
assert_errs!(CooMatrix::<i32>::try_from_triplets(3, 5, vec![1], vec![0], vec![0, 1])); 5,
assert_errs!(CooMatrix::<i32>::try_from_triplets(3, 5, vec![1, 2], vec![0, 1], vec![0])); vec![1, 2],
assert_errs!(CooMatrix::<i32>::try_from_triplets(3, 5, vec![1], vec![0, 1], vec![0, 1])); vec![0],
assert_errs!(CooMatrix::<i32>::try_from_triplets(3, 5, vec![1, 1], vec![0], vec![0, 1])); vec![0]
));
assert_errs!(CooMatrix::<i32>::try_from_triplets(
3,
5,
vec![1],
vec![0, 0],
vec![0]
));
assert_errs!(CooMatrix::<i32>::try_from_triplets(
3,
5,
vec![1],
vec![0],
vec![0, 1]
));
assert_errs!(CooMatrix::<i32>::try_from_triplets(
3,
5,
vec![1, 2],
vec![0, 1],
vec![0]
));
assert_errs!(CooMatrix::<i32>::try_from_triplets(
3,
5,
vec![1],
vec![0, 1],
vec![0, 1]
));
assert_errs!(CooMatrix::<i32>::try_from_triplets(
3,
5,
vec![1, 1],
vec![0],
vec![0, 1]
));
} }
#[test] #[test]
@ -157,10 +214,16 @@ fn coo_push_valid_entries() {
assert_eq!(coo.triplet_iter().collect::<Vec<_>>(), vec![(0, 0, &1)]); assert_eq!(coo.triplet_iter().collect::<Vec<_>>(), vec![(0, 0, &1)]);
coo.push(0, 0, 2); coo.push(0, 0, 2);
assert_eq!(coo.triplet_iter().collect::<Vec<_>>(), vec![(0, 0, &1), (0, 0, &2)]); assert_eq!(
coo.triplet_iter().collect::<Vec<_>>(),
vec![(0, 0, &1), (0, 0, &2)]
);
coo.push(2, 2, 3); coo.push(2, 2, 3);
assert_eq!(coo.triplet_iter().collect::<Vec<_>>(), vec![(0, 0, &1), (0, 0, &2), (2, 2, &3)]); assert_eq!(
coo.triplet_iter().collect::<Vec<_>>(),
vec![(0, 0, &1), (0, 0, &2), (2, 2, &3)]
);
} }
#[test] #[test]

View File

@ -1,6 +1,6 @@
use nalgebra::DMatrix;
use nalgebra_sparse::csc::CscMatrix; use nalgebra_sparse::csc::CscMatrix;
use nalgebra_sparse::SparseFormatErrorKind; use nalgebra_sparse::SparseFormatErrorKind;
use nalgebra::DMatrix;
use proptest::prelude::*; use proptest::prelude::*;
use proptest::sample::subsequence; use proptest::sample::subsequence;
@ -42,7 +42,10 @@ fn csc_matrix_valid_data() {
assert_eq!(matrix.col_mut(0).row_indices(), &[]); assert_eq!(matrix.col_mut(0).row_indices(), &[]);
assert_eq!(matrix.col_mut(0).values(), &[]); assert_eq!(matrix.col_mut(0).values(), &[]);
assert_eq!(matrix.col_mut(0).values_mut(), &[]); assert_eq!(matrix.col_mut(0).values_mut(), &[]);
assert_eq!(matrix.col_mut(0).rows_and_values_mut(), ([].as_ref(), [].as_mut())); assert_eq!(
matrix.col_mut(0).rows_and_values_mut(),
([].as_ref(), [].as_mut())
);
assert_eq!(matrix.col(1).nrows(), 2); assert_eq!(matrix.col(1).nrows(), 2);
assert_eq!(matrix.col(1).nnz(), 0); assert_eq!(matrix.col(1).nnz(), 0);
@ -53,7 +56,10 @@ fn csc_matrix_valid_data() {
assert_eq!(matrix.col_mut(1).row_indices(), &[]); assert_eq!(matrix.col_mut(1).row_indices(), &[]);
assert_eq!(matrix.col_mut(1).values(), &[]); assert_eq!(matrix.col_mut(1).values(), &[]);
assert_eq!(matrix.col_mut(1).values_mut(), &[]); assert_eq!(matrix.col_mut(1).values_mut(), &[]);
assert_eq!(matrix.col_mut(1).rows_and_values_mut(), ([].as_ref(), [].as_mut())); assert_eq!(
matrix.col_mut(1).rows_and_values_mut(),
([].as_ref(), [].as_mut())
);
assert_eq!(matrix.col(2).nrows(), 2); assert_eq!(matrix.col(2).nrows(), 2);
assert_eq!(matrix.col(2).nnz(), 0); assert_eq!(matrix.col(2).nnz(), 0);
@ -64,7 +70,10 @@ fn csc_matrix_valid_data() {
assert_eq!(matrix.col_mut(2).row_indices(), &[]); assert_eq!(matrix.col_mut(2).row_indices(), &[]);
assert_eq!(matrix.col_mut(2).values(), &[]); assert_eq!(matrix.col_mut(2).values(), &[]);
assert_eq!(matrix.col_mut(2).values_mut(), &[]); assert_eq!(matrix.col_mut(2).values_mut(), &[]);
assert_eq!(matrix.col_mut(2).rows_and_values_mut(), ([].as_ref(), [].as_mut())); assert_eq!(
matrix.col_mut(2).rows_and_values_mut(),
([].as_ref(), [].as_mut())
);
assert!(matrix.get_col(3).is_none()); assert!(matrix.get_col(3).is_none());
assert!(matrix.get_col_mut(3).is_none()); assert!(matrix.get_col_mut(3).is_none());
@ -81,11 +90,9 @@ fn csc_matrix_valid_data() {
let offsets = vec![0, 2, 2, 5]; let offsets = vec![0, 2, 2, 5];
let indices = vec![0, 5, 1, 2, 3]; let indices = vec![0, 5, 1, 2, 3];
let values = vec![0, 1, 2, 3, 4]; let values = vec![0, 1, 2, 3, 4];
let mut matrix = CscMatrix::try_from_csc_data(6, let mut matrix =
3, CscMatrix::try_from_csc_data(6, 3, offsets.clone(), indices.clone(), values.clone())
offsets.clone(), .unwrap();
indices.clone(),
values.clone()).unwrap();
assert_eq!(matrix.nrows(), 6); assert_eq!(matrix.nrows(), 6);
assert_eq!(matrix.ncols(), 3); assert_eq!(matrix.ncols(), 3);
@ -95,10 +102,20 @@ fn csc_matrix_valid_data() {
assert_eq!(matrix.values(), &[0, 1, 2, 3, 4]); assert_eq!(matrix.values(), &[0, 1, 2, 3, 4]);
let expected_triplets = vec![(0, 0, 0), (5, 0, 1), (1, 2, 2), (2, 2, 3), (3, 2, 4)]; let expected_triplets = vec![(0, 0, 0), (5, 0, 1), (1, 2, 2), (2, 2, 3), (3, 2, 4)];
assert_eq!(matrix.triplet_iter().map(|(i, j, v)| (i, j, *v)).collect::<Vec<_>>(), assert_eq!(
expected_triplets); matrix
assert_eq!(matrix.triplet_iter_mut().map(|(i, j, v)| (i, j, *v)).collect::<Vec<_>>(), .triplet_iter()
expected_triplets); .map(|(i, j, v)| (i, j, *v))
.collect::<Vec<_>>(),
expected_triplets
);
assert_eq!(
matrix
.triplet_iter_mut()
.map(|(i, j, v)| (i, j, *v))
.collect::<Vec<_>>(),
expected_triplets
);
assert_eq!(matrix.col(0).nrows(), 6); assert_eq!(matrix.col(0).nrows(), 6);
assert_eq!(matrix.col(0).nnz(), 2); assert_eq!(matrix.col(0).nnz(), 2);
@ -109,7 +126,10 @@ fn csc_matrix_valid_data() {
assert_eq!(matrix.col_mut(0).row_indices(), &[0, 5]); assert_eq!(matrix.col_mut(0).row_indices(), &[0, 5]);
assert_eq!(matrix.col_mut(0).values(), &[0, 1]); assert_eq!(matrix.col_mut(0).values(), &[0, 1]);
assert_eq!(matrix.col_mut(0).values_mut(), &[0, 1]); assert_eq!(matrix.col_mut(0).values_mut(), &[0, 1]);
assert_eq!(matrix.col_mut(0).rows_and_values_mut(), ([0, 5].as_ref(), [0, 1].as_mut())); assert_eq!(
matrix.col_mut(0).rows_and_values_mut(),
([0, 5].as_ref(), [0, 1].as_mut())
);
assert_eq!(matrix.col(1).nrows(), 6); assert_eq!(matrix.col(1).nrows(), 6);
assert_eq!(matrix.col(1).nnz(), 0); assert_eq!(matrix.col(1).nnz(), 0);
@ -120,7 +140,10 @@ fn csc_matrix_valid_data() {
assert_eq!(matrix.col_mut(1).row_indices(), &[]); assert_eq!(matrix.col_mut(1).row_indices(), &[]);
assert_eq!(matrix.col_mut(1).values(), &[]); assert_eq!(matrix.col_mut(1).values(), &[]);
assert_eq!(matrix.col_mut(1).values_mut(), &[]); assert_eq!(matrix.col_mut(1).values_mut(), &[]);
assert_eq!(matrix.col_mut(1).rows_and_values_mut(), ([].as_ref(), [].as_mut())); assert_eq!(
matrix.col_mut(1).rows_and_values_mut(),
([].as_ref(), [].as_mut())
);
assert_eq!(matrix.col(2).nrows(), 6); assert_eq!(matrix.col(2).nrows(), 6);
assert_eq!(matrix.col(2).nnz(), 3); assert_eq!(matrix.col(2).nnz(), 3);
@ -131,7 +154,10 @@ fn csc_matrix_valid_data() {
assert_eq!(matrix.col_mut(2).row_indices(), &[1, 2, 3]); assert_eq!(matrix.col_mut(2).row_indices(), &[1, 2, 3]);
assert_eq!(matrix.col_mut(2).values(), &[2, 3, 4]); assert_eq!(matrix.col_mut(2).values(), &[2, 3, 4]);
assert_eq!(matrix.col_mut(2).values_mut(), &[2, 3, 4]); assert_eq!(matrix.col_mut(2).values_mut(), &[2, 3, 4]);
assert_eq!(matrix.col_mut(2).rows_and_values_mut(), ([1, 2, 3].as_ref(), [2, 3, 4].as_mut())); assert_eq!(
matrix.col_mut(2).rows_and_values_mut(),
([1, 2, 3].as_ref(), [2, 3, 4].as_mut())
);
assert!(matrix.get_col(3).is_none()); assert!(matrix.get_col(3).is_none());
assert!(matrix.get_col_mut(3).is_none()); assert!(matrix.get_col_mut(3).is_none());
@ -146,11 +172,13 @@ fn csc_matrix_valid_data() {
#[test] #[test]
fn csc_matrix_try_from_invalid_csc_data() { fn csc_matrix_try_from_invalid_csc_data() {
{ {
// Empty offset array (invalid length) // Empty offset array (invalid length)
let matrix = CscMatrix::try_from_csc_data(0, 0, Vec::new(), Vec::new(), Vec::<u32>::new()); let matrix = CscMatrix::try_from_csc_data(0, 0, Vec::new(), Vec::new(), Vec::<u32>::new());
assert_eq!(matrix.unwrap_err().kind(), &SparseFormatErrorKind::InvalidStructure); assert_eq!(
matrix.unwrap_err().kind(),
&SparseFormatErrorKind::InvalidStructure
);
} }
{ {
@ -160,7 +188,10 @@ fn csc_matrix_try_from_invalid_csc_data() {
let values = vec![0, 1, 2, 3, 4]; let values = vec![0, 1, 2, 3, 4];
let matrix = CscMatrix::try_from_csc_data(6, 3, offsets, indices, values); let matrix = CscMatrix::try_from_csc_data(6, 3, offsets, indices, values);
assert_eq!(matrix.unwrap_err().kind(), &SparseFormatErrorKind::InvalidStructure); assert_eq!(
matrix.unwrap_err().kind(),
&SparseFormatErrorKind::InvalidStructure
);
} }
{ {
@ -169,7 +200,10 @@ fn csc_matrix_try_from_invalid_csc_data() {
let indices = vec![0, 5, 1, 2, 3]; let indices = vec![0, 5, 1, 2, 3];
let values = vec![0, 1, 2, 3, 4]; let values = vec![0, 1, 2, 3, 4];
let matrix = CscMatrix::try_from_csc_data(6, 3, offsets, indices, values); let matrix = CscMatrix::try_from_csc_data(6, 3, offsets, indices, values);
assert_eq!(matrix.unwrap_err().kind(), &SparseFormatErrorKind::InvalidStructure); assert_eq!(
matrix.unwrap_err().kind(),
&SparseFormatErrorKind::InvalidStructure
);
} }
{ {
@ -178,7 +212,10 @@ fn csc_matrix_try_from_invalid_csc_data() {
let indices = vec![0, 5, 1, 2, 3]; let indices = vec![0, 5, 1, 2, 3];
let values = vec![0, 1, 2, 3, 4]; let values = vec![0, 1, 2, 3, 4];
let matrix = CscMatrix::try_from_csc_data(6, 3, offsets, indices, values); let matrix = CscMatrix::try_from_csc_data(6, 3, offsets, indices, values);
assert_eq!(matrix.unwrap_err().kind(), &SparseFormatErrorKind::InvalidStructure); assert_eq!(
matrix.unwrap_err().kind(),
&SparseFormatErrorKind::InvalidStructure
);
} }
{ {
@ -187,7 +224,10 @@ fn csc_matrix_try_from_invalid_csc_data() {
let indices = vec![0, 5, 1, 2, 3]; let indices = vec![0, 5, 1, 2, 3];
let values = vec![0, 1, 2, 3, 4]; let values = vec![0, 1, 2, 3, 4];
let matrix = CscMatrix::try_from_csc_data(6, 3, offsets, indices, values); let matrix = CscMatrix::try_from_csc_data(6, 3, offsets, indices, values);
assert_eq!(matrix.unwrap_err().kind(), &SparseFormatErrorKind::InvalidStructure); assert_eq!(
matrix.unwrap_err().kind(),
&SparseFormatErrorKind::InvalidStructure
);
} }
{ {
@ -196,7 +236,10 @@ fn csc_matrix_try_from_invalid_csc_data() {
let indices = vec![0, 1, 2, 3, 4]; let indices = vec![0, 1, 2, 3, 4];
let values = vec![0, 1, 2, 3, 4]; let values = vec![0, 1, 2, 3, 4];
let matrix = CscMatrix::try_from_csc_data(6, 3, offsets, indices, values); let matrix = CscMatrix::try_from_csc_data(6, 3, offsets, indices, values);
assert_eq!(matrix.unwrap_err().kind(), &SparseFormatErrorKind::InvalidStructure); assert_eq!(
matrix.unwrap_err().kind(),
&SparseFormatErrorKind::InvalidStructure
);
} }
{ {
@ -205,7 +248,10 @@ fn csc_matrix_try_from_invalid_csc_data() {
let indices = vec![0, 2, 3, 1, 4]; let indices = vec![0, 2, 3, 1, 4];
let values = vec![0, 1, 2, 3, 4]; let values = vec![0, 1, 2, 3, 4];
let matrix = CscMatrix::try_from_csc_data(6, 3, offsets, indices, values); let matrix = CscMatrix::try_from_csc_data(6, 3, offsets, indices, values);
assert_eq!(matrix.unwrap_err().kind(), &SparseFormatErrorKind::InvalidStructure); assert_eq!(
matrix.unwrap_err().kind(),
&SparseFormatErrorKind::InvalidStructure
);
} }
{ {
@ -214,7 +260,10 @@ fn csc_matrix_try_from_invalid_csc_data() {
let indices = vec![0, 6, 1, 2, 3]; let indices = vec![0, 6, 1, 2, 3];
let values = vec![0, 1, 2, 3, 4]; let values = vec![0, 1, 2, 3, 4];
let matrix = CscMatrix::try_from_csc_data(6, 3, offsets, indices, values); let matrix = CscMatrix::try_from_csc_data(6, 3, offsets, indices, values);
assert_eq!(matrix.unwrap_err().kind(), &SparseFormatErrorKind::IndexOutOfBounds); assert_eq!(
matrix.unwrap_err().kind(),
&SparseFormatErrorKind::IndexOutOfBounds
);
} }
{ {
@ -223,9 +272,11 @@ fn csc_matrix_try_from_invalid_csc_data() {
let indices = vec![0, 5, 2, 2, 3]; let indices = vec![0, 5, 2, 2, 3];
let values = vec![0, 1, 2, 3, 4]; let values = vec![0, 1, 2, 3, 4];
let matrix = CscMatrix::try_from_csc_data(6, 3, offsets, indices, values); let matrix = CscMatrix::try_from_csc_data(6, 3, offsets, indices, values);
assert_eq!(matrix.unwrap_err().kind(), &SparseFormatErrorKind::DuplicateEntry); assert_eq!(
matrix.unwrap_err().kind(),
&SparseFormatErrorKind::DuplicateEntry
);
} }
} }
#[test] #[test]
@ -239,11 +290,7 @@ fn csc_disassemble_avoids_clone_when_owned() {
let offsets_ptr = offsets.as_ptr(); let offsets_ptr = offsets.as_ptr();
let indices_ptr = indices.as_ptr(); let indices_ptr = indices.as_ptr();
let values_ptr = values.as_ptr(); let values_ptr = values.as_ptr();
let matrix = CscMatrix::try_from_csc_data(6, let matrix = CscMatrix::try_from_csc_data(6, 3, offsets, indices, values).unwrap();
3,
offsets,
indices,
values).unwrap();
let (offsets, indices, values) = matrix.disassemble(); let (offsets, indices, values) = matrix.disassemble();
assert_eq!(offsets.as_ptr(), offsets_ptr); assert_eq!(offsets.as_ptr(), offsets_ptr);

View File

@ -1,6 +1,6 @@
use nalgebra::DMatrix;
use nalgebra_sparse::csr::CsrMatrix; use nalgebra_sparse::csr::CsrMatrix;
use nalgebra_sparse::SparseFormatErrorKind; use nalgebra_sparse::SparseFormatErrorKind;
use nalgebra::DMatrix;
use proptest::prelude::*; use proptest::prelude::*;
use proptest::sample::subsequence; use proptest::sample::subsequence;
@ -9,7 +9,6 @@ use crate::common::csr_strategy;
use std::collections::HashSet; use std::collections::HashSet;
#[test] #[test]
fn csr_matrix_valid_data() { fn csr_matrix_valid_data() {
// Construct matrix from valid data and check that selected methods return results // Construct matrix from valid data and check that selected methods return results
@ -43,7 +42,10 @@ fn csr_matrix_valid_data() {
assert_eq!(matrix.row_mut(0).col_indices(), &[]); assert_eq!(matrix.row_mut(0).col_indices(), &[]);
assert_eq!(matrix.row_mut(0).values(), &[]); assert_eq!(matrix.row_mut(0).values(), &[]);
assert_eq!(matrix.row_mut(0).values_mut(), &[]); assert_eq!(matrix.row_mut(0).values_mut(), &[]);
assert_eq!(matrix.row_mut(0).cols_and_values_mut(), ([].as_ref(), [].as_mut())); assert_eq!(
matrix.row_mut(0).cols_and_values_mut(),
([].as_ref(), [].as_mut())
);
assert_eq!(matrix.row(1).ncols(), 2); assert_eq!(matrix.row(1).ncols(), 2);
assert_eq!(matrix.row(1).nnz(), 0); assert_eq!(matrix.row(1).nnz(), 0);
@ -54,7 +56,10 @@ fn csr_matrix_valid_data() {
assert_eq!(matrix.row_mut(1).col_indices(), &[]); assert_eq!(matrix.row_mut(1).col_indices(), &[]);
assert_eq!(matrix.row_mut(1).values(), &[]); assert_eq!(matrix.row_mut(1).values(), &[]);
assert_eq!(matrix.row_mut(1).values_mut(), &[]); assert_eq!(matrix.row_mut(1).values_mut(), &[]);
assert_eq!(matrix.row_mut(1).cols_and_values_mut(), ([].as_ref(), [].as_mut())); assert_eq!(
matrix.row_mut(1).cols_and_values_mut(),
([].as_ref(), [].as_mut())
);
assert_eq!(matrix.row(2).ncols(), 2); assert_eq!(matrix.row(2).ncols(), 2);
assert_eq!(matrix.row(2).nnz(), 0); assert_eq!(matrix.row(2).nnz(), 0);
@ -65,7 +70,10 @@ fn csr_matrix_valid_data() {
assert_eq!(matrix.row_mut(2).col_indices(), &[]); assert_eq!(matrix.row_mut(2).col_indices(), &[]);
assert_eq!(matrix.row_mut(2).values(), &[]); assert_eq!(matrix.row_mut(2).values(), &[]);
assert_eq!(matrix.row_mut(2).values_mut(), &[]); assert_eq!(matrix.row_mut(2).values_mut(), &[]);
assert_eq!(matrix.row_mut(2).cols_and_values_mut(), ([].as_ref(), [].as_mut())); assert_eq!(
matrix.row_mut(2).cols_and_values_mut(),
([].as_ref(), [].as_mut())
);
assert!(matrix.get_row(3).is_none()); assert!(matrix.get_row(3).is_none());
assert!(matrix.get_row_mut(3).is_none()); assert!(matrix.get_row_mut(3).is_none());
@ -82,11 +90,9 @@ fn csr_matrix_valid_data() {
let offsets = vec![0, 2, 2, 5]; let offsets = vec![0, 2, 2, 5];
let indices = vec![0, 5, 1, 2, 3]; let indices = vec![0, 5, 1, 2, 3];
let values = vec![0, 1, 2, 3, 4]; let values = vec![0, 1, 2, 3, 4];
let mut matrix = CsrMatrix::try_from_csr_data(3, let mut matrix =
6, CsrMatrix::try_from_csr_data(3, 6, offsets.clone(), indices.clone(), values.clone())
offsets.clone(), .unwrap();
indices.clone(),
values.clone()).unwrap();
assert_eq!(matrix.nrows(), 3); assert_eq!(matrix.nrows(), 3);
assert_eq!(matrix.ncols(), 6); assert_eq!(matrix.ncols(), 6);
@ -96,10 +102,20 @@ fn csr_matrix_valid_data() {
assert_eq!(matrix.values(), &[0, 1, 2, 3, 4]); assert_eq!(matrix.values(), &[0, 1, 2, 3, 4]);
let expected_triplets = vec![(0, 0, 0), (0, 5, 1), (2, 1, 2), (2, 2, 3), (2, 3, 4)]; let expected_triplets = vec![(0, 0, 0), (0, 5, 1), (2, 1, 2), (2, 2, 3), (2, 3, 4)];
assert_eq!(matrix.triplet_iter().map(|(i, j, v)| (i, j, *v)).collect::<Vec<_>>(), assert_eq!(
expected_triplets); matrix
assert_eq!(matrix.triplet_iter_mut().map(|(i, j, v)| (i, j, *v)).collect::<Vec<_>>(), .triplet_iter()
expected_triplets); .map(|(i, j, v)| (i, j, *v))
.collect::<Vec<_>>(),
expected_triplets
);
assert_eq!(
matrix
.triplet_iter_mut()
.map(|(i, j, v)| (i, j, *v))
.collect::<Vec<_>>(),
expected_triplets
);
assert_eq!(matrix.row(0).ncols(), 6); assert_eq!(matrix.row(0).ncols(), 6);
assert_eq!(matrix.row(0).nnz(), 2); assert_eq!(matrix.row(0).nnz(), 2);
@ -110,7 +126,10 @@ fn csr_matrix_valid_data() {
assert_eq!(matrix.row_mut(0).col_indices(), &[0, 5]); assert_eq!(matrix.row_mut(0).col_indices(), &[0, 5]);
assert_eq!(matrix.row_mut(0).values(), &[0, 1]); assert_eq!(matrix.row_mut(0).values(), &[0, 1]);
assert_eq!(matrix.row_mut(0).values_mut(), &[0, 1]); assert_eq!(matrix.row_mut(0).values_mut(), &[0, 1]);
assert_eq!(matrix.row_mut(0).cols_and_values_mut(), ([0, 5].as_ref(), [0, 1].as_mut())); assert_eq!(
matrix.row_mut(0).cols_and_values_mut(),
([0, 5].as_ref(), [0, 1].as_mut())
);
assert_eq!(matrix.row(1).ncols(), 6); assert_eq!(matrix.row(1).ncols(), 6);
assert_eq!(matrix.row(1).nnz(), 0); assert_eq!(matrix.row(1).nnz(), 0);
@ -121,7 +140,10 @@ fn csr_matrix_valid_data() {
assert_eq!(matrix.row_mut(1).col_indices(), &[]); assert_eq!(matrix.row_mut(1).col_indices(), &[]);
assert_eq!(matrix.row_mut(1).values(), &[]); assert_eq!(matrix.row_mut(1).values(), &[]);
assert_eq!(matrix.row_mut(1).values_mut(), &[]); assert_eq!(matrix.row_mut(1).values_mut(), &[]);
assert_eq!(matrix.row_mut(1).cols_and_values_mut(), ([].as_ref(), [].as_mut())); assert_eq!(
matrix.row_mut(1).cols_and_values_mut(),
([].as_ref(), [].as_mut())
);
assert_eq!(matrix.row(2).ncols(), 6); assert_eq!(matrix.row(2).ncols(), 6);
assert_eq!(matrix.row(2).nnz(), 3); assert_eq!(matrix.row(2).nnz(), 3);
@ -132,7 +154,10 @@ fn csr_matrix_valid_data() {
assert_eq!(matrix.row_mut(2).col_indices(), &[1, 2, 3]); assert_eq!(matrix.row_mut(2).col_indices(), &[1, 2, 3]);
assert_eq!(matrix.row_mut(2).values(), &[2, 3, 4]); assert_eq!(matrix.row_mut(2).values(), &[2, 3, 4]);
assert_eq!(matrix.row_mut(2).values_mut(), &[2, 3, 4]); assert_eq!(matrix.row_mut(2).values_mut(), &[2, 3, 4]);
assert_eq!(matrix.row_mut(2).cols_and_values_mut(), ([1, 2, 3].as_ref(), [2, 3, 4].as_mut())); assert_eq!(
matrix.row_mut(2).cols_and_values_mut(),
([1, 2, 3].as_ref(), [2, 3, 4].as_mut())
);
assert!(matrix.get_row(3).is_none()); assert!(matrix.get_row(3).is_none());
assert!(matrix.get_row_mut(3).is_none()); assert!(matrix.get_row_mut(3).is_none());
@ -147,11 +172,13 @@ fn csr_matrix_valid_data() {
#[test] #[test]
fn csr_matrix_try_from_invalid_csr_data() { fn csr_matrix_try_from_invalid_csr_data() {
{ {
// Empty offset array (invalid length) // Empty offset array (invalid length)
let matrix = CsrMatrix::try_from_csr_data(0, 0, Vec::new(), Vec::new(), Vec::<u32>::new()); let matrix = CsrMatrix::try_from_csr_data(0, 0, Vec::new(), Vec::new(), Vec::<u32>::new());
assert_eq!(matrix.unwrap_err().kind(), &SparseFormatErrorKind::InvalidStructure); assert_eq!(
matrix.unwrap_err().kind(),
&SparseFormatErrorKind::InvalidStructure
);
} }
{ {
@ -161,7 +188,10 @@ fn csr_matrix_try_from_invalid_csr_data() {
let values = vec![0, 1, 2, 3, 4]; let values = vec![0, 1, 2, 3, 4];
let matrix = CsrMatrix::try_from_csr_data(3, 6, offsets, indices, values); let matrix = CsrMatrix::try_from_csr_data(3, 6, offsets, indices, values);
assert_eq!(matrix.unwrap_err().kind(), &SparseFormatErrorKind::InvalidStructure); assert_eq!(
matrix.unwrap_err().kind(),
&SparseFormatErrorKind::InvalidStructure
);
} }
{ {
@ -170,7 +200,10 @@ fn csr_matrix_try_from_invalid_csr_data() {
let indices = vec![0, 5, 1, 2, 3]; let indices = vec![0, 5, 1, 2, 3];
let values = vec![0, 1, 2, 3, 4]; let values = vec![0, 1, 2, 3, 4];
let matrix = CsrMatrix::try_from_csr_data(3, 6, offsets, indices, values); let matrix = CsrMatrix::try_from_csr_data(3, 6, offsets, indices, values);
assert_eq!(matrix.unwrap_err().kind(), &SparseFormatErrorKind::InvalidStructure); assert_eq!(
matrix.unwrap_err().kind(),
&SparseFormatErrorKind::InvalidStructure
);
} }
{ {
@ -179,7 +212,10 @@ fn csr_matrix_try_from_invalid_csr_data() {
let indices = vec![0, 5, 1, 2, 3]; let indices = vec![0, 5, 1, 2, 3];
let values = vec![0, 1, 2, 3, 4]; let values = vec![0, 1, 2, 3, 4];
let matrix = CsrMatrix::try_from_csr_data(3, 6, offsets, indices, values); let matrix = CsrMatrix::try_from_csr_data(3, 6, offsets, indices, values);
assert_eq!(matrix.unwrap_err().kind(), &SparseFormatErrorKind::InvalidStructure); assert_eq!(
matrix.unwrap_err().kind(),
&SparseFormatErrorKind::InvalidStructure
);
} }
{ {
@ -188,7 +224,10 @@ fn csr_matrix_try_from_invalid_csr_data() {
let indices = vec![0, 5, 1, 2, 3]; let indices = vec![0, 5, 1, 2, 3];
let values = vec![0, 1, 2, 3, 4]; let values = vec![0, 1, 2, 3, 4];
let matrix = CsrMatrix::try_from_csr_data(3, 6, offsets, indices, values); let matrix = CsrMatrix::try_from_csr_data(3, 6, offsets, indices, values);
assert_eq!(matrix.unwrap_err().kind(), &SparseFormatErrorKind::InvalidStructure); assert_eq!(
matrix.unwrap_err().kind(),
&SparseFormatErrorKind::InvalidStructure
);
} }
{ {
@ -197,7 +236,10 @@ fn csr_matrix_try_from_invalid_csr_data() {
let indices = vec![0, 1, 2, 3, 4]; let indices = vec![0, 1, 2, 3, 4];
let values = vec![0, 1, 2, 3, 4]; let values = vec![0, 1, 2, 3, 4];
let matrix = CsrMatrix::try_from_csr_data(3, 6, offsets, indices, values); let matrix = CsrMatrix::try_from_csr_data(3, 6, offsets, indices, values);
assert_eq!(matrix.unwrap_err().kind(), &SparseFormatErrorKind::InvalidStructure); assert_eq!(
matrix.unwrap_err().kind(),
&SparseFormatErrorKind::InvalidStructure
);
} }
{ {
@ -206,7 +248,10 @@ fn csr_matrix_try_from_invalid_csr_data() {
let indices = vec![0, 2, 3, 1, 4]; let indices = vec![0, 2, 3, 1, 4];
let values = vec![0, 1, 2, 3, 4]; let values = vec![0, 1, 2, 3, 4];
let matrix = CsrMatrix::try_from_csr_data(3, 6, offsets, indices, values); let matrix = CsrMatrix::try_from_csr_data(3, 6, offsets, indices, values);
assert_eq!(matrix.unwrap_err().kind(), &SparseFormatErrorKind::InvalidStructure); assert_eq!(
matrix.unwrap_err().kind(),
&SparseFormatErrorKind::InvalidStructure
);
} }
{ {
@ -215,7 +260,10 @@ fn csr_matrix_try_from_invalid_csr_data() {
let indices = vec![0, 6, 1, 2, 3]; let indices = vec![0, 6, 1, 2, 3];
let values = vec![0, 1, 2, 3, 4]; let values = vec![0, 1, 2, 3, 4];
let matrix = CsrMatrix::try_from_csr_data(3, 6, offsets, indices, values); let matrix = CsrMatrix::try_from_csr_data(3, 6, offsets, indices, values);
assert_eq!(matrix.unwrap_err().kind(), &SparseFormatErrorKind::IndexOutOfBounds); assert_eq!(
matrix.unwrap_err().kind(),
&SparseFormatErrorKind::IndexOutOfBounds
);
} }
{ {
@ -224,9 +272,11 @@ fn csr_matrix_try_from_invalid_csr_data() {
let indices = vec![0, 5, 2, 2, 3]; let indices = vec![0, 5, 2, 2, 3];
let values = vec![0, 1, 2, 3, 4]; let values = vec![0, 1, 2, 3, 4];
let matrix = CsrMatrix::try_from_csr_data(3, 6, offsets, indices, values); let matrix = CsrMatrix::try_from_csr_data(3, 6, offsets, indices, values);
assert_eq!(matrix.unwrap_err().kind(), &SparseFormatErrorKind::DuplicateEntry); assert_eq!(
matrix.unwrap_err().kind(),
&SparseFormatErrorKind::DuplicateEntry
);
} }
} }
#[test] #[test]
@ -240,11 +290,7 @@ fn csr_disassemble_avoids_clone_when_owned() {
let offsets_ptr = offsets.as_ptr(); let offsets_ptr = offsets.as_ptr();
let indices_ptr = indices.as_ptr(); let indices_ptr = indices.as_ptr();
let values_ptr = values.as_ptr(); let values_ptr = values.as_ptr();
let matrix = CsrMatrix::try_from_csr_data(3, let matrix = CsrMatrix::try_from_csr_data(3, 6, offsets, indices, values).unwrap();
6,
offsets,
indices,
values).unwrap();
let (offsets, indices, values) = matrix.disassemble(); let (offsets, indices, values) = matrix.disassemble();
assert_eq!(offsets.as_ptr(), offsets_ptr); assert_eq!(offsets.as_ptr(), offsets_ptr);

View File

@ -1,8 +1,8 @@
mod coo;
mod cholesky; mod cholesky;
mod convert_serial;
mod coo;
mod csc;
mod csr;
mod ops; mod ops;
mod pattern; mod pattern;
mod csr;
mod csc;
mod convert_serial;
mod proptest; mod proptest;

View File

@ -1,13 +1,19 @@
use crate::common::{csc_strategy, csr_strategy, PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ, PROPTEST_I32_VALUE_STRATEGY, non_zero_i32_value_strategy, value_strategy}; use crate::common::{
use nalgebra_sparse::ops::serial::{spmm_csr_dense, spmm_csc_dense, spadd_pattern, spadd_csr_prealloc, spadd_csc_prealloc, spmm_csr_prealloc, spmm_csc_prealloc, spsolve_csc_lower_triangular, spmm_csr_pattern}; csc_strategy, csr_strategy, non_zero_i32_value_strategy, value_strategy,
use nalgebra_sparse::ops::{Op}; PROPTEST_I32_VALUE_STRATEGY, PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ,
use nalgebra_sparse::csr::CsrMatrix; };
use nalgebra_sparse::csc::CscMatrix; use nalgebra_sparse::csc::CscMatrix;
use nalgebra_sparse::proptest::{csc, csr, sparsity_pattern}; use nalgebra_sparse::csr::CsrMatrix;
use nalgebra_sparse::ops::serial::{
spadd_csc_prealloc, spadd_csr_prealloc, spadd_pattern, spmm_csc_dense, spmm_csc_prealloc,
spmm_csr_dense, spmm_csr_pattern, spmm_csr_prealloc, spsolve_csc_lower_triangular,
};
use nalgebra_sparse::ops::Op;
use nalgebra_sparse::pattern::SparsityPattern; use nalgebra_sparse::pattern::SparsityPattern;
use nalgebra_sparse::proptest::{csc, csr, sparsity_pattern};
use nalgebra::{DMatrix, Scalar, DMatrixSliceMut, DMatrixSlice};
use nalgebra::proptest::{matrix, vector}; use nalgebra::proptest::{matrix, vector};
use nalgebra::{DMatrix, DMatrixSlice, DMatrixSliceMut, Scalar};
use proptest::prelude::*; use proptest::prelude::*;
@ -17,19 +23,15 @@ use std::panic::catch_unwind;
/// Represents the sparsity pattern of a CSR matrix as a dense matrix with 0/1 /// Represents the sparsity pattern of a CSR matrix as a dense matrix with 0/1
fn dense_csr_pattern(pattern: &SparsityPattern) -> DMatrix<i32> { fn dense_csr_pattern(pattern: &SparsityPattern) -> DMatrix<i32> {
let boolean_csr = CsrMatrix::try_from_pattern_and_values( let boolean_csr =
pattern.clone(), CsrMatrix::try_from_pattern_and_values(pattern.clone(), vec![1; pattern.nnz()]).unwrap();
vec![1; pattern.nnz()])
.unwrap();
DMatrix::from(&boolean_csr) DMatrix::from(&boolean_csr)
} }
/// Represents the sparsity pattern of a CSC matrix as a dense matrix with 0/1 /// Represents the sparsity pattern of a CSC matrix as a dense matrix with 0/1
fn dense_csc_pattern(pattern: &SparsityPattern) -> DMatrix<i32> { fn dense_csc_pattern(pattern: &SparsityPattern) -> DMatrix<i32> {
let boolean_csc = CscMatrix::try_from_pattern_and_values( let boolean_csc =
pattern.clone(), CscMatrix::try_from_pattern_and_values(pattern.clone(), vec![1; pattern.nnz()]).unwrap();
vec![1; pattern.nnz()])
.unwrap();
DMatrix::from(&boolean_csc) DMatrix::from(&boolean_csc)
} }
@ -62,14 +64,23 @@ fn spmm_csr_dense_args_strategy() -> impl Strategy<Value=SpmmCsrDenseArgs<i32>>
let trans_strategy = trans_strategy(); let trans_strategy = trans_strategy();
let c_matrix_strategy = matrix(value_strategy.clone(), c_rows, c_cols); let c_matrix_strategy = matrix(value_strategy.clone(), c_rows, c_cols);
(c_matrix_strategy, common_dim, trans_strategy.clone(), trans_strategy.clone()) (
c_matrix_strategy,
common_dim,
trans_strategy.clone(),
trans_strategy.clone(),
)
.prop_flat_map(move |(c, common_dim, trans_a, trans_b)| { .prop_flat_map(move |(c, common_dim, trans_a, trans_b)| {
let a_shape = let a_shape = if trans_a {
if trans_a { (common_dim, c.nrows()) } (common_dim, c.nrows())
else { (c.nrows(), common_dim) }; } else {
let b_shape = (c.nrows(), common_dim)
if trans_b { (c.ncols(), common_dim) } };
else { (common_dim, c.ncols()) }; let b_shape = if trans_b {
(c.ncols(), common_dim)
} else {
(common_dim, c.ncols())
};
let a = csr(value_strategy.clone(), a_shape.0, a_shape.1, max_nnz); let a = csr(value_strategy.clone(), a_shape.0, a_shape.1, max_nnz);
let b = matrix(value_strategy.clone(), b_shape.0, b_shape.1); let b = matrix(value_strategy.clone(), b_shape.0, b_shape.1);
@ -78,29 +89,35 @@ fn spmm_csr_dense_args_strategy() -> impl Strategy<Value=SpmmCsrDenseArgs<i32>>
let beta = value_strategy.clone(); let beta = value_strategy.clone();
(Just(c), beta, alpha, Just(trans_a), a, Just(trans_b), b) (Just(c), beta, alpha, Just(trans_a), a, Just(trans_b), b)
}).prop_map(|(c, beta, alpha, trans_a, a, trans_b, b)| { })
SpmmCsrDenseArgs { .prop_map(
|(c, beta, alpha, trans_a, a, trans_b, b)| SpmmCsrDenseArgs {
c, c,
beta, beta,
alpha, alpha,
a: if trans_a { Op::Transpose(a) } else { Op::NoOp(a) }, a: if trans_a {
b: if trans_b { Op::Transpose(b) } else { Op::NoOp(b) }, Op::Transpose(a)
} } else {
}) Op::NoOp(a)
},
b: if trans_b {
Op::Transpose(b)
} else {
Op::NoOp(b)
},
},
)
} }
/// Returns matrices C, A and B with compatible dimensions such that it can be used /// Returns matrices C, A and B with compatible dimensions such that it can be used
/// in an `spmm` operation `C = beta * C + alpha * trans(A) * trans(B)`. /// in an `spmm` operation `C = beta * C + alpha * trans(A) * trans(B)`.
fn spmm_csc_dense_args_strategy() -> impl Strategy<Value = SpmmCscDenseArgs<i32>> { fn spmm_csc_dense_args_strategy() -> impl Strategy<Value = SpmmCscDenseArgs<i32>> {
spmm_csr_dense_args_strategy() spmm_csr_dense_args_strategy().prop_map(|args| SpmmCscDenseArgs {
.prop_map(|args| {
SpmmCscDenseArgs {
c: args.c, c: args.c,
beta: args.beta, beta: args.beta,
alpha: args.alpha, alpha: args.alpha,
a: args.a.map_same_op(|a| CscMatrix::from(&a)), a: args.a.map_same_op(|a| CscMatrix::from(&a)),
b: args.b b: args.b,
}
}) })
} }
@ -131,28 +148,46 @@ fn spadd_csr_prealloc_args_strategy() -> impl Strategy<Value=SpaddCsrArgs<i32>>
let c_values = vec![value_strategy.clone(); c_pattern.nnz()]; let c_values = vec![value_strategy.clone(); c_pattern.nnz()];
let alpha = value_strategy.clone(); let alpha = value_strategy.clone();
let beta = value_strategy.clone(); let beta = value_strategy.clone();
(Just(c_pattern), Just(a_pattern), c_values, a_values, alpha, beta, trans_strategy()) (
}).prop_map(|(c_pattern, a_pattern, c_values, a_values, alpha, beta, trans_a)| { Just(c_pattern),
Just(a_pattern),
c_values,
a_values,
alpha,
beta,
trans_strategy(),
)
})
.prop_map(
|(c_pattern, a_pattern, c_values, a_values, alpha, beta, trans_a)| {
let c = CsrMatrix::try_from_pattern_and_values(c_pattern, c_values).unwrap(); let c = CsrMatrix::try_from_pattern_and_values(c_pattern, c_values).unwrap();
let a = CsrMatrix::try_from_pattern_and_values(a_pattern, a_values).unwrap(); let a = CsrMatrix::try_from_pattern_and_values(a_pattern, a_values).unwrap();
let a = if trans_a { Op::Transpose(a.transpose()) } else { Op::NoOp(a) }; let a = if trans_a {
Op::Transpose(a.transpose())
} else {
Op::NoOp(a)
};
SpaddCsrArgs { c, beta, alpha, a } SpaddCsrArgs { c, beta, alpha, a }
}) },
)
} }
fn spadd_csc_prealloc_args_strategy() -> impl Strategy<Value = SpaddCscArgs<i32>> { fn spadd_csc_prealloc_args_strategy() -> impl Strategy<Value = SpaddCscArgs<i32>> {
spadd_csr_prealloc_args_strategy() spadd_csr_prealloc_args_strategy().prop_map(|args| SpaddCscArgs {
.prop_map(|args| SpaddCscArgs {
c: CscMatrix::from(&args.c), c: CscMatrix::from(&args.c),
beta: args.beta, beta: args.beta,
alpha: args.alpha, alpha: args.alpha,
a: args.a.map_same_op(|a| CscMatrix::from(&a)) a: args.a.map_same_op(|a| CscMatrix::from(&a)),
}) })
} }
fn dense_strategy() -> impl Strategy<Value = DMatrix<i32>> { fn dense_strategy() -> impl Strategy<Value = DMatrix<i32>> {
matrix(PROPTEST_I32_VALUE_STRATEGY, PROPTEST_MATRIX_DIM, PROPTEST_MATRIX_DIM) matrix(
PROPTEST_I32_VALUE_STRATEGY,
PROPTEST_MATRIX_DIM,
PROPTEST_MATRIX_DIM,
)
} }
fn trans_strategy() -> impl Strategy<Value = bool> + Clone { fn trans_strategy() -> impl Strategy<Value = bool> + Clone {
@ -163,11 +198,12 @@ fn trans_strategy() -> impl Strategy<Value=bool> + Clone {
/// values. /// values.
fn op_strategy<S: Strategy>(strategy: S) -> impl Strategy<Value = Op<S::Value>> { fn op_strategy<S: Strategy>(strategy: S) -> impl Strategy<Value = Op<S::Value>> {
let is_transposed = proptest::bool::ANY; let is_transposed = proptest::bool::ANY;
(strategy, is_transposed) (strategy, is_transposed).prop_map(|(obj, is_trans)| {
.prop_map(|(obj, is_trans)| if is_trans { if is_trans {
Op::Transpose(obj) Op::Transpose(obj)
} else { } else {
Op::NoOp(obj) Op::NoOp(obj)
}
}) })
} }
@ -177,8 +213,7 @@ fn pattern_strategy() -> impl Strategy<Value=SparsityPattern> {
/// Constructs pairs (a, b) where a and b have the same dimensions /// Constructs pairs (a, b) where a and b have the same dimensions
fn spadd_pattern_strategy() -> impl Strategy<Value = (SparsityPattern, SparsityPattern)> { fn spadd_pattern_strategy() -> impl Strategy<Value = (SparsityPattern, SparsityPattern)> {
pattern_strategy() pattern_strategy().prop_flat_map(|a| {
.prop_flat_map(|a| {
let b = sparsity_pattern(a.major_dim(), a.minor_dim(), PROPTEST_MAX_NNZ); let b = sparsity_pattern(a.major_dim(), a.minor_dim(), PROPTEST_MAX_NNZ);
(Just(a), b) (Just(a), b)
}) })
@ -186,8 +221,7 @@ fn spadd_pattern_strategy() -> impl Strategy<Value=(SparsityPattern, SparsityPat
/// Constructs pairs (a, b) where a and b have compatible dimensions for a matrix product /// Constructs pairs (a, b) where a and b have compatible dimensions for a matrix product
fn spmm_csr_pattern_strategy() -> impl Strategy<Value = (SparsityPattern, SparsityPattern)> { fn spmm_csr_pattern_strategy() -> impl Strategy<Value = (SparsityPattern, SparsityPattern)> {
pattern_strategy() pattern_strategy().prop_flat_map(|a| {
.prop_flat_map(|a| {
let b = sparsity_pattern(a.minor_dim(), PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ); let b = sparsity_pattern(a.minor_dim(), PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ);
(Just(a), b) (Just(a), b)
}) })
@ -218,48 +252,55 @@ fn spmm_csr_prealloc_args_strategy() -> impl Strategy<Value=SpmmCsrArgs<i32>> {
let b_values = vec![PROPTEST_I32_VALUE_STRATEGY; b_pattern.nnz()]; let b_values = vec![PROPTEST_I32_VALUE_STRATEGY; b_pattern.nnz()];
let c_pattern = spmm_csr_pattern(&a_pattern, &b_pattern); let c_pattern = spmm_csr_pattern(&a_pattern, &b_pattern);
let c_values = vec![PROPTEST_I32_VALUE_STRATEGY; c_pattern.nnz()]; let c_values = vec![PROPTEST_I32_VALUE_STRATEGY; c_pattern.nnz()];
let a = a_values.prop_map(move |values| let a = a_values.prop_map(move |values| {
CsrMatrix::try_from_pattern_and_values(a_pattern.clone(), values).unwrap()); CsrMatrix::try_from_pattern_and_values(a_pattern.clone(), values).unwrap()
let b = b_values.prop_map(move |values| });
CsrMatrix::try_from_pattern_and_values(b_pattern.clone(), values).unwrap()); let b = b_values.prop_map(move |values| {
let c = c_values.prop_map(move |values| CsrMatrix::try_from_pattern_and_values(b_pattern.clone(), values).unwrap()
CsrMatrix::try_from_pattern_and_values(c_pattern.clone(), values).unwrap()); });
let c = c_values.prop_map(move |values| {
CsrMatrix::try_from_pattern_and_values(c_pattern.clone(), values).unwrap()
});
let alpha = PROPTEST_I32_VALUE_STRATEGY; let alpha = PROPTEST_I32_VALUE_STRATEGY;
let beta = PROPTEST_I32_VALUE_STRATEGY; let beta = PROPTEST_I32_VALUE_STRATEGY;
(c, beta, alpha, trans_strategy(), a, trans_strategy(), b) (c, beta, alpha, trans_strategy(), a, trans_strategy(), b)
}) })
.prop_map(|(c, beta, alpha, trans_a, a, trans_b, b)| { .prop_map(
SpmmCsrArgs::<i32> { |(c, beta, alpha, trans_a, a, trans_b, b)| SpmmCsrArgs::<i32> {
c, c,
beta, beta,
alpha, alpha,
a: if trans_a { Op::Transpose(a.transpose()) } else { Op::NoOp(a) }, a: if trans_a {
b: if trans_b { Op::Transpose(b.transpose()) } else { Op::NoOp(b) } Op::Transpose(a.transpose())
} } else {
}) Op::NoOp(a)
},
b: if trans_b {
Op::Transpose(b.transpose())
} else {
Op::NoOp(b)
},
},
)
} }
fn spmm_csc_prealloc_args_strategy() -> impl Strategy<Value = SpmmCscArgs<i32>> { fn spmm_csc_prealloc_args_strategy() -> impl Strategy<Value = SpmmCscArgs<i32>> {
// Note: Converting from CSR is simple, but might be significantly slower than // Note: Converting from CSR is simple, but might be significantly slower than
// writing a common implementation that can be shared between CSR and CSC args // writing a common implementation that can be shared between CSR and CSC args
spmm_csr_prealloc_args_strategy() spmm_csr_prealloc_args_strategy().prop_map(|args| SpmmCscArgs {
.prop_map(|args| {
SpmmCscArgs {
c: CscMatrix::from(&args.c), c: CscMatrix::from(&args.c),
beta: args.beta, beta: args.beta,
alpha: args.alpha, alpha: args.alpha,
a: args.a.map_same_op(|a| CscMatrix::from(&a)), a: args.a.map_same_op(|a| CscMatrix::from(&a)),
b: args.b.map_same_op(|b| CscMatrix::from(&b)) b: args.b.map_same_op(|b| CscMatrix::from(&b)),
}
}) })
} }
fn csc_invertible_diagonal() -> impl Strategy<Value = CscMatrix<f64>> { fn csc_invertible_diagonal() -> impl Strategy<Value = CscMatrix<f64>> {
let non_zero_values = value_strategy::<f64>() let non_zero_values =
.prop_filter("Only non-zeros values accepted", |x| x != &0.0); value_strategy::<f64>().prop_filter("Only non-zeros values accepted", |x| x != &0.0);
vector(non_zero_values, PROPTEST_MATRIX_DIM) vector(non_zero_values, PROPTEST_MATRIX_DIM).prop_map(|d| {
.prop_map(|d| {
let mut matrix = CscMatrix::identity(d.len()); let mut matrix = CscMatrix::identity(d.len());
matrix.values_mut().clone_from_slice(&d.as_slice()); matrix.values_mut().clone_from_slice(&d.as_slice());
matrix matrix
@ -267,9 +308,13 @@ fn csc_invertible_diagonal() -> impl Strategy<Value=CscMatrix<f64>> {
} }
fn csc_square_with_non_zero_diagonals() -> impl Strategy<Value = CscMatrix<f64>> { fn csc_square_with_non_zero_diagonals() -> impl Strategy<Value = CscMatrix<f64>> {
csc_invertible_diagonal() csc_invertible_diagonal().prop_flat_map(|d| {
.prop_flat_map(|d| { csc(
csc(value_strategy::<f64>(), d.nrows(), d.nrows(), PROPTEST_MAX_NNZ) value_strategy::<f64>(),
d.nrows(),
d.nrows(),
PROPTEST_MAX_NNZ,
)
.prop_map(move |mut c| { .prop_map(move |mut c| {
for (i, j, v) in c.triplet_iter_mut() { for (i, j, v) in c.triplet_iter_mut() {
if i == j { if i == j {
@ -285,12 +330,13 @@ fn csc_square_with_non_zero_diagonals() -> impl Strategy<Value=CscMatrix<f64>> {
} }
/// Helper function to help us call dense GEMM with our `Op` type /// Helper function to help us call dense GEMM with our `Op` type
fn dense_gemm<'a>(beta: i32, fn dense_gemm<'a>(
beta: i32,
c: impl Into<DMatrixSliceMut<'a, i32>>, c: impl Into<DMatrixSliceMut<'a, i32>>,
alpha: i32, alpha: i32,
a: Op<impl Into<DMatrixSlice<'a, i32>>>, a: Op<impl Into<DMatrixSlice<'a, i32>>>,
b: Op<impl Into<DMatrixSlice<'a, i32>>>) b: Op<impl Into<DMatrixSlice<'a, i32>>>,
{ ) {
let mut c = c.into(); let mut c = c.into();
let a = a.convert(); let a = a.convert();
let b = b.convert(); let b = b.convert();
@ -300,7 +346,7 @@ fn dense_gemm<'a>(beta: i32,
(NoOp(a), NoOp(b)) => c.gemm(alpha, &a, &b, beta), (NoOp(a), NoOp(b)) => c.gemm(alpha, &a, &b, beta),
(Transpose(a), NoOp(b)) => c.gemm(alpha, &a.transpose(), &b, beta), (Transpose(a), NoOp(b)) => c.gemm(alpha, &a.transpose(), &b, beta),
(NoOp(a), Transpose(b)) => c.gemm(alpha, &a, &b.transpose(), beta), (NoOp(a), Transpose(b)) => c.gemm(alpha, &a, &b.transpose(), beta),
(Transpose(a), Transpose(b)) => c.gemm(alpha, &a.transpose(), &b.transpose(), beta) (Transpose(a), Transpose(b)) => c.gemm(alpha, &a.transpose(), &b.transpose(), beta),
} }
} }

View File

@ -7,10 +7,8 @@ fn sparsity_pattern_valid_data() {
{ {
// A pattern with zero explicitly stored entries // A pattern with zero explicitly stored entries
let pattern = SparsityPattern::try_from_offsets_and_indices(3, let pattern =
2, SparsityPattern::try_from_offsets_and_indices(3, 2, vec![0, 0, 0, 0], Vec::new())
vec![0, 0, 0, 0],
Vec::new())
.unwrap(); .unwrap();
assert_eq!(pattern.major_dim(), 3); assert_eq!(pattern.major_dim(), 3);
@ -46,8 +44,10 @@ fn sparsity_pattern_valid_data() {
assert_eq!(pattern.lane(0), &[0, 5]); assert_eq!(pattern.lane(0), &[0, 5]);
assert_eq!(pattern.lane(1), &[]); assert_eq!(pattern.lane(1), &[]);
assert_eq!(pattern.lane(2), &[1, 2, 3]); assert_eq!(pattern.lane(2), &[1, 2, 3]);
assert_eq!(pattern.entries().collect::<Vec<_>>(), assert_eq!(
vec![(0, 0), (0, 5), (2, 1), (2, 2), (2, 3)]); pattern.entries().collect::<Vec<_>>(),
vec![(0, 0), (0, 5), (2, 1), (2, 2), (2, 3)]
);
let (offsets2, indices2) = pattern.disassemble(); let (offsets2, indices2) = pattern.disassemble();
assert_eq!(offsets2, offsets); assert_eq!(offsets2, offsets);
@ -60,7 +60,10 @@ fn sparsity_pattern_try_from_invalid_data() {
{ {
// Empty offset array (invalid length) // Empty offset array (invalid length)
let pattern = SparsityPattern::try_from_offsets_and_indices(0, 0, Vec::new(), Vec::new()); let pattern = SparsityPattern::try_from_offsets_and_indices(0, 0, Vec::new(), Vec::new());
assert_eq!(pattern, Err(SparsityPatternFormatError::InvalidOffsetArrayLength)); assert_eq!(
pattern,
Err(SparsityPatternFormatError::InvalidOffsetArrayLength)
);
} }
{ {
@ -69,7 +72,10 @@ fn sparsity_pattern_try_from_invalid_data() {
let indices = vec![0, 1, 2, 3, 5]; let indices = vec![0, 1, 2, 3, 5];
let pattern = SparsityPattern::try_from_offsets_and_indices(3, 6, offsets, indices); let pattern = SparsityPattern::try_from_offsets_and_indices(3, 6, offsets, indices);
assert!(matches!(pattern, Err(SparsityPatternFormatError::InvalidOffsetArrayLength))); assert!(matches!(
pattern,
Err(SparsityPatternFormatError::InvalidOffsetArrayLength)
));
} }
{ {
@ -77,7 +83,10 @@ fn sparsity_pattern_try_from_invalid_data() {
let offsets = vec![1, 2, 2, 5]; let offsets = vec![1, 2, 2, 5];
let indices = vec![0, 5, 1, 2, 3]; let indices = vec![0, 5, 1, 2, 3];
let pattern = SparsityPattern::try_from_offsets_and_indices(3, 6, offsets, indices); let pattern = SparsityPattern::try_from_offsets_and_indices(3, 6, offsets, indices);
assert!(matches!(pattern, Err(SparsityPatternFormatError::InvalidOffsetFirstLast))); assert!(matches!(
pattern,
Err(SparsityPatternFormatError::InvalidOffsetFirstLast)
));
} }
{ {
@ -85,7 +94,10 @@ fn sparsity_pattern_try_from_invalid_data() {
let offsets = vec![0, 2, 2, 4]; let offsets = vec![0, 2, 2, 4];
let indices = vec![0, 5, 1, 2, 3]; let indices = vec![0, 5, 1, 2, 3];
let pattern = SparsityPattern::try_from_offsets_and_indices(3, 6, offsets, indices); let pattern = SparsityPattern::try_from_offsets_and_indices(3, 6, offsets, indices);
assert!(matches!(pattern, Err(SparsityPatternFormatError::InvalidOffsetFirstLast))); assert!(matches!(
pattern,
Err(SparsityPatternFormatError::InvalidOffsetFirstLast)
));
} }
{ {
@ -93,7 +105,10 @@ fn sparsity_pattern_try_from_invalid_data() {
let offsets = vec![0, 2, 2]; let offsets = vec![0, 2, 2];
let indices = vec![0, 5, 1, 2, 3]; let indices = vec![0, 5, 1, 2, 3];
let pattern = SparsityPattern::try_from_offsets_and_indices(3, 6, offsets, indices); let pattern = SparsityPattern::try_from_offsets_and_indices(3, 6, offsets, indices);
assert!(matches!(pattern, Err(SparsityPatternFormatError::InvalidOffsetArrayLength))); assert!(matches!(
pattern,
Err(SparsityPatternFormatError::InvalidOffsetArrayLength)
));
} }
{ {
@ -101,7 +116,10 @@ fn sparsity_pattern_try_from_invalid_data() {
let offsets = vec![0, 3, 2, 5]; let offsets = vec![0, 3, 2, 5];
let indices = vec![0, 1, 2, 3, 4]; let indices = vec![0, 1, 2, 3, 4];
let pattern = SparsityPattern::try_from_offsets_and_indices(3, 6, offsets, indices); let pattern = SparsityPattern::try_from_offsets_and_indices(3, 6, offsets, indices);
assert_eq!(pattern, Err(SparsityPatternFormatError::NonmonotonicOffsets)); assert_eq!(
pattern,
Err(SparsityPatternFormatError::NonmonotonicOffsets)
);
} }
{ {
@ -109,7 +127,10 @@ fn sparsity_pattern_try_from_invalid_data() {
let offsets = vec![0, 2, 2, 5]; let offsets = vec![0, 2, 2, 5];
let indices = vec![0, 2, 3, 1, 4]; let indices = vec![0, 2, 3, 1, 4];
let pattern = SparsityPattern::try_from_offsets_and_indices(3, 6, offsets, indices); let pattern = SparsityPattern::try_from_offsets_and_indices(3, 6, offsets, indices);
assert_eq!(pattern, Err(SparsityPatternFormatError::NonmonotonicMinorIndices)); assert_eq!(
pattern,
Err(SparsityPatternFormatError::NonmonotonicMinorIndices)
);
} }
{ {
@ -117,7 +138,10 @@ fn sparsity_pattern_try_from_invalid_data() {
let offsets = vec![0, 2, 2, 5]; let offsets = vec![0, 2, 2, 5];
let indices = vec![0, 6, 1, 2, 3]; let indices = vec![0, 6, 1, 2, 3];
let pattern = SparsityPattern::try_from_offsets_and_indices(3, 6, offsets, indices); let pattern = SparsityPattern::try_from_offsets_and_indices(3, 6, offsets, indices);
assert_eq!(pattern, Err(SparsityPatternFormatError::MinorIndexOutOfBounds)); assert_eq!(
pattern,
Err(SparsityPatternFormatError::MinorIndexOutOfBounds)
);
} }
{ {

View File

@ -6,25 +6,27 @@ fn coo_no_duplicates_generates_admissible_matrices() {
#[cfg(feature = "slow-tests")] #[cfg(feature = "slow-tests")]
mod slow { mod slow {
use nalgebra_sparse::proptest::{coo_with_duplicates, coo_no_duplicates, csr, csc, sparsity_pattern};
use nalgebra::DMatrix; use nalgebra::DMatrix;
use nalgebra_sparse::proptest::{
coo_no_duplicates, coo_with_duplicates, csc, csr, sparsity_pattern,
};
use proptest::test_runner::TestRunner;
use proptest::strategy::ValueTree;
use itertools::Itertools; use itertools::Itertools;
use proptest::strategy::ValueTree;
use proptest::test_runner::TestRunner;
use proptest::prelude::*; use proptest::prelude::*;
use nalgebra_sparse::csr::CsrMatrix;
use std::collections::HashSet; use std::collections::HashSet;
use std::iter::repeat; use std::iter::repeat;
use std::ops::RangeInclusive; use std::ops::RangeInclusive;
use nalgebra_sparse::csr::CsrMatrix;
fn generate_all_possible_matrices(value_range: RangeInclusive<i32>, fn generate_all_possible_matrices(
value_range: RangeInclusive<i32>,
rows_range: RangeInclusive<usize>, rows_range: RangeInclusive<usize>,
cols_range: RangeInclusive<usize>) cols_range: RangeInclusive<usize>,
-> HashSet<DMatrix<i32>> ) -> HashSet<DMatrix<i32>> {
{
// Enumerate all possible combinations // Enumerate all possible combinations
let mut all_combinations = HashSet::new(); let mut all_combinations = HashSet::new();
for nrows in rows_range { for nrows in rows_range {
@ -48,7 +50,11 @@ mod slow {
.take(n_values) .take(n_values)
.multi_cartesian_product(); .multi_cartesian_product();
for matrix_values in values_iter { for matrix_values in values_iter {
all_combinations.insert(DMatrix::from_row_slice(nrows, ncols, &matrix_values)); all_combinations.insert(DMatrix::from_row_slice(
nrows,
ncols,
&matrix_values,
));
} }
} }
} }
@ -80,12 +86,14 @@ mod slow {
// Enumerate all possible combinations // Enumerate all possible combinations
let all_combinations = generate_all_possible_matrices(values, rows, cols); let all_combinations = generate_all_possible_matrices(values, rows, cols);
let visited_combinations = sample_matrix_output_space(strategy, let visited_combinations =
&mut runner, sample_matrix_output_space(strategy, &mut runner, num_generated_matrices);
num_generated_matrices);
assert_eq!(visited_combinations.len(), all_combinations.len()); assert_eq!(visited_combinations.len(), all_combinations.len());
assert_eq!(visited_combinations, all_combinations, "Did not sample all possible values."); assert_eq!(
visited_combinations, all_combinations,
"Did not sample all possible values."
);
} }
#[cfg(feature = "slow-tests")] #[cfg(feature = "slow-tests")]
@ -113,9 +121,8 @@ mod slow {
// `coo_with_duplicates`) // `coo_with_duplicates`)
let all_combinations = generate_all_possible_matrices(values, rows, cols); let all_combinations = generate_all_possible_matrices(values, rows, cols);
let visited_combinations = sample_matrix_output_space(strategy, let visited_combinations =
&mut runner, sample_matrix_output_space(strategy, &mut runner, num_generated_matrices);
num_generated_matrices);
// Here we cannot verify that the set of visited combinations is *equal* to // Here we cannot verify that the set of visited combinations is *equal* to
// all possible outcomes with the given constraints, however the // all possible outcomes with the given constraints, however the
@ -143,12 +150,14 @@ mod slow {
let all_combinations = generate_all_possible_matrices(values, rows, cols); let all_combinations = generate_all_possible_matrices(values, rows, cols);
let visited_combinations = sample_matrix_output_space(strategy, let visited_combinations =
&mut runner, sample_matrix_output_space(strategy, &mut runner, num_generated_matrices);
num_generated_matrices);
assert_eq!(visited_combinations.len(), all_combinations.len()); assert_eq!(visited_combinations.len(), all_combinations.len());
assert_eq!(visited_combinations, all_combinations, "Did not sample all possible values."); assert_eq!(
visited_combinations, all_combinations,
"Did not sample all possible values."
);
} }
#[cfg(feature = "slow-tests")] #[cfg(feature = "slow-tests")]
@ -169,12 +178,14 @@ mod slow {
let all_combinations = generate_all_possible_matrices(values, rows, cols); let all_combinations = generate_all_possible_matrices(values, rows, cols);
let visited_combinations = sample_matrix_output_space(strategy, let visited_combinations =
&mut runner, sample_matrix_output_space(strategy, &mut runner, num_generated_matrices);
num_generated_matrices);
assert_eq!(visited_combinations.len(), all_combinations.len()); assert_eq!(visited_combinations.len(), all_combinations.len());
assert_eq!(visited_combinations, all_combinations, "Did not sample all possible values."); assert_eq!(
visited_combinations, all_combinations,
"Did not sample all possible values."
);
} }
#[cfg(feature = "slow-tests")] #[cfg(feature = "slow-tests")]
@ -206,13 +217,14 @@ mod slow {
assert_eq!(visited_patterns, all_possible_patterns); assert_eq!(visited_patterns, all_possible_patterns);
} }
fn sample_matrix_output_space<S>(strategy: S, fn sample_matrix_output_space<S>(
strategy: S,
runner: &mut TestRunner, runner: &mut TestRunner,
num_samples: usize) num_samples: usize,
-> HashSet<DMatrix<i32>> ) -> HashSet<DMatrix<i32>>
where where
S: Strategy, S: Strategy,
DMatrix<i32>: for<'b> From<&'b S::Value> DMatrix<i32>: for<'b> From<&'b S::Value>,
{ {
sample_strategy(strategy, runner) sample_strategy(strategy, runner)
.take(num_samples) .take(num_samples)
@ -220,8 +232,10 @@ mod slow {
.collect() .collect()
} }
fn sample_strategy<'a, S: 'a + Strategy>(strategy: S, runner: &'a mut TestRunner) fn sample_strategy<'a, S: 'a + Strategy>(
-> impl 'a + Iterator<Item=S::Value> { strategy: S,
runner: &'a mut TestRunner,
) -> impl 'a + Iterator<Item = S::Value> {
repeat(()).map(move |_| { repeat(()).map(move |_| {
let tree = strategy let tree = strategy
.new_tree(runner) .new_tree(runner)