This commit is contained in:
Andreas Longva 2021-01-25 17:26:27 +01:00
parent 795d818ae5
commit 7473d54d74
30 changed files with 1477 additions and 1069 deletions

View File

@ -1,17 +1,17 @@
use crate::coo::CooMatrix;
use crate::convert::serial::*;
use nalgebra::{Matrix, Scalar, Dim, ClosedAdd, DMatrix};
use nalgebra::storage::{Storage};
use num_traits::Zero;
use crate::csr::CsrMatrix;
use crate::coo::CooMatrix;
use crate::csc::CscMatrix;
use crate::csr::CsrMatrix;
use nalgebra::storage::Storage;
use nalgebra::{ClosedAdd, DMatrix, Dim, Matrix, Scalar};
use num_traits::Zero;
impl<'a, T, R, C, S> From<&'a Matrix<T, R, C, S>> for CooMatrix<T>
where
T: Scalar + Zero,
R: Dim,
C: Dim,
S: Storage<T, R, C>
S: Storage<T, R, C>,
{
fn from(matrix: &'a Matrix<T, R, C, S>) -> Self {
convert_dense_coo(matrix)
@ -29,7 +29,7 @@ where
impl<'a, T> From<&'a CooMatrix<T>> for CsrMatrix<T>
where
T: Scalar + Zero + ClosedAdd
T: Scalar + Zero + ClosedAdd,
{
fn from(matrix: &'a CooMatrix<T>) -> Self {
convert_coo_csr(matrix)
@ -38,7 +38,7 @@ where
impl<'a, T> From<&'a CsrMatrix<T>> for CooMatrix<T>
where
T: Scalar + Zero + ClosedAdd
T: Scalar + Zero + ClosedAdd,
{
fn from(matrix: &'a CsrMatrix<T>) -> Self {
convert_csr_coo(matrix)
@ -50,7 +50,7 @@ where
T: Scalar + Zero,
R: Dim,
C: Dim,
S: Storage<T, R, C>
S: Storage<T, R, C>,
{
fn from(matrix: &'a Matrix<T, R, C, S>) -> Self {
convert_dense_csr(matrix)
@ -59,7 +59,7 @@ where
impl<'a, T> From<&'a CsrMatrix<T>> for DMatrix<T>
where
T: Scalar + Zero + ClosedAdd
T: Scalar + Zero + ClosedAdd,
{
fn from(matrix: &'a CsrMatrix<T>) -> Self {
convert_csr_dense(matrix)
@ -68,7 +68,7 @@ where
impl<'a, T> From<&'a CooMatrix<T>> for CscMatrix<T>
where
T: Scalar + Zero + ClosedAdd
T: Scalar + Zero + ClosedAdd,
{
fn from(matrix: &'a CooMatrix<T>) -> Self {
convert_coo_csc(matrix)
@ -77,7 +77,7 @@ where
impl<'a, T> From<&'a CscMatrix<T>> for CooMatrix<T>
where
T: Scalar + Zero
T: Scalar + Zero,
{
fn from(matrix: &'a CscMatrix<T>) -> Self {
convert_csc_coo(matrix)
@ -89,7 +89,7 @@ impl<'a, T, R, C, S> From<&'a Matrix<T, R, C, S>> for CscMatrix<T>
T: Scalar + Zero,
R: Dim,
C: Dim,
S: Storage<T, R, C>
S: Storage<T, R, C>,
{
fn from(matrix: &'a Matrix<T, R, C, S>) -> Self {
convert_dense_csc(matrix)
@ -98,7 +98,7 @@ impl<'a, T, R, C, S> From<&'a Matrix<T, R, C, S>> for CscMatrix<T>
impl<'a, T> From<&'a CscMatrix<T>> for DMatrix<T>
where
T: Scalar + Zero + ClosedAdd
T: Scalar + Zero + ClosedAdd,
{
fn from(matrix: &'a CscMatrix<T>) -> Self {
convert_csc_dense(matrix)
@ -107,7 +107,7 @@ impl<'a, T> From<&'a CscMatrix<T>> for DMatrix<T>
impl<'a, T> From<&'a CscMatrix<T>> for CsrMatrix<T>
where
T: Scalar
T: Scalar,
{
fn from(matrix: &'a CscMatrix<T>) -> Self {
convert_csc_csr(matrix)
@ -116,7 +116,7 @@ impl<'a, T> From<&'a CscMatrix<T>> for CsrMatrix<T>
impl<'a, T> From<&'a CsrMatrix<T>> for CscMatrix<T>
where
T: Scalar
T: Scalar,
{
fn from(matrix: &'a CsrMatrix<T>) -> Self {
convert_csr_csc(matrix)

View File

@ -7,8 +7,8 @@ use std::ops::Add;
use num_traits::Zero;
use nalgebra::{ClosedAdd, Dim, DMatrix, Matrix, Scalar};
use nalgebra::storage::Storage;
use nalgebra::{ClosedAdd, DMatrix, Dim, Matrix, Scalar};
use crate::coo::CooMatrix;
use crate::cs;
@ -21,7 +21,7 @@ where
T: Scalar + Zero,
R: Dim,
C: Dim,
S: Storage<T, R, C>
S: Storage<T, R, C>,
{
let mut coo = CooMatrix::new(dense.nrows(), dense.ncols());
@ -52,12 +52,14 @@ where
/// Converts a [`CooMatrix`] to a [`CsrMatrix`].
pub fn convert_coo_csr<T>(coo: &CooMatrix<T>) -> CsrMatrix<T>
where
T: Scalar + Zero
T: Scalar + Zero,
{
let (offsets, indices, values) = convert_coo_cs(coo.nrows(),
let (offsets, indices, values) = convert_coo_cs(
coo.nrows(),
coo.row_indices(),
coo.col_indices(),
coo.values());
coo.values(),
);
// TODO: Avoid "try_from" since it validates the data? (requires unsafe, should benchmark
// to see if it can be justified for performance reasons)
@ -66,8 +68,7 @@ where
}
/// Converts a [`CsrMatrix`] to a [`CooMatrix`].
pub fn convert_csr_coo<T: Scalar>(csr: &CsrMatrix<T>) -> CooMatrix<T>
{
pub fn convert_csr_coo<T: Scalar>(csr: &CsrMatrix<T>) -> CooMatrix<T> {
let mut result = CooMatrix::new(csr.nrows(), csr.ncols());
for (i, j, v) in csr.triplet_iter() {
result.push(i, j, v.inlined_clone());
@ -78,7 +79,7 @@ pub fn convert_csr_coo<T: Scalar>(csr: &CsrMatrix<T>) -> CooMatrix<T>
/// Converts a [`CsrMatrix`] to a dense matrix.
pub fn convert_csr_dense<T>(csr: &CsrMatrix<T>) -> DMatrix<T>
where
T: Scalar + ClosedAdd + Zero
T: Scalar + ClosedAdd + Zero,
{
let mut output = DMatrix::zeros(csr.nrows(), csr.ncols());
@ -95,7 +96,7 @@ where
T: Scalar + Zero,
R: Dim,
C: Dim,
S: Storage<T, R, C>
S: Storage<T, R, C>,
{
let mut row_offsets = Vec::with_capacity(dense.nrows() + 1);
let mut col_idx = Vec::new();
@ -125,12 +126,14 @@ where
/// Converts a [`CooMatrix`] to a [`CscMatrix`].
pub fn convert_coo_csc<T>(coo: &CooMatrix<T>) -> CscMatrix<T>
where
T: Scalar + Zero
T: Scalar + Zero,
{
let (offsets, indices, values) = convert_coo_cs(coo.ncols(),
let (offsets, indices, values) = convert_coo_cs(
coo.ncols(),
coo.col_indices(),
coo.row_indices(),
coo.values());
coo.values(),
);
// TODO: Avoid "try_from" since it validates the data? (requires unsafe, should benchmark
// to see if it can be justified for performance reasons)
@ -141,7 +144,7 @@ where
/// Converts a [`CscMatrix`] to a [`CooMatrix`].
pub fn convert_csc_coo<T>(csc: &CscMatrix<T>) -> CooMatrix<T>
where
T: Scalar
T: Scalar,
{
let mut coo = CooMatrix::new(csc.nrows(), csc.ncols());
for (i, j, v) in csc.triplet_iter() {
@ -153,7 +156,7 @@ where
/// Converts a [`CscMatrix`] to a dense matrix.
pub fn convert_csc_dense<T>(csc: &CscMatrix<T>) -> DMatrix<T>
where
T: Scalar + ClosedAdd + Zero
T: Scalar + ClosedAdd + Zero,
{
let mut output = DMatrix::zeros(csc.nrows(), csc.ncols());
@ -170,7 +173,7 @@ pub fn convert_dense_csc<T, R, C, S>(dense: &Matrix<T, R, C, S>) -> CscMatrix<T>
T: Scalar + Zero,
R: Dim,
C: Dim,
S: Storage<T, R, C>
S: Storage<T, R, C>,
{
let mut col_offsets = Vec::with_capacity(dense.ncols() + 1);
let mut row_idx = Vec::new();
@ -197,13 +200,15 @@ pub fn convert_dense_csc<T, R, C, S>(dense: &Matrix<T, R, C, S>) -> CscMatrix<T>
/// Converts a [`CsrMatrix`] to a [`CscMatrix`].
pub fn convert_csr_csc<T>(csr: &CsrMatrix<T>) -> CscMatrix<T>
where
T: Scalar
T: Scalar,
{
let (offsets, indices, values) = cs::transpose_cs(csr.nrows(),
let (offsets, indices, values) = cs::transpose_cs(
csr.nrows(),
csr.ncols(),
csr.row_offsets(),
csr.col_indices(),
csr.values());
csr.values(),
);
// TODO: Avoid data validity check?
CscMatrix::try_from_csc_data(csr.nrows(), csr.ncols(), offsets, indices, values)
@ -213,26 +218,29 @@ where
/// Converts a [`CscMatrix`] to a [`CsrMatrix`].
pub fn convert_csc_csr<T>(csc: &CscMatrix<T>) -> CsrMatrix<T>
where
T: Scalar
T: Scalar,
{
let (offsets, indices, values) = cs::transpose_cs(csc.ncols(),
let (offsets, indices, values) = cs::transpose_cs(
csc.ncols(),
csc.nrows(),
csc.col_offsets(),
csc.row_indices(),
csc.values());
csc.values(),
);
// TODO: Avoid data validity check?
CsrMatrix::try_from_csr_data(csc.nrows(), csc.ncols(), offsets, indices, values)
.expect("Internal error: Invalid CSR data during CSC->CSR conversion")
}
fn convert_coo_cs<T>(major_dim: usize,
fn convert_coo_cs<T>(
major_dim: usize,
major_indices: &[usize],
minor_indices: &[usize],
values: &[T])
-> (Vec<usize>, Vec<usize>, Vec<T>)
values: &[T],
) -> (Vec<usize>, Vec<usize>, Vec<T>)
where
T: Scalar + Zero
T: Scalar + Zero,
{
assert_eq!(major_indices.len(), minor_indices.len());
assert_eq!(minor_indices.len(), values.len());

View File

@ -45,8 +45,7 @@ pub struct CooMatrix<T> {
values: Vec<T>,
}
impl<T> CooMatrix<T>
{
impl<T> CooMatrix<T> {
/// Construct a zero COO matrix of the given dimensions.
///
/// Specifically, the collection of triplets - corresponding to explicitly stored entries -
@ -78,11 +77,13 @@ impl<T> CooMatrix<T>
use crate::SparseFormatErrorKind::*;
if row_indices.len() != col_indices.len() {
return Err(SparseFormatError::from_kind_and_msg(
InvalidStructure, "Number of row and col indices must be the same."
InvalidStructure,
"Number of row and col indices must be the same.",
));
} else if col_indices.len() != values.len() {
return Err(SparseFormatError::from_kind_and_msg(
InvalidStructure, "Number of col indices and values must be the same."
InvalidStructure,
"Number of col indices and values must be the same.",
));
}
@ -90,9 +91,15 @@ impl<T> CooMatrix<T>
let col_indices_in_bounds = col_indices.iter().all(|j| *j < ncols);
if !row_indices_in_bounds {
Err(SparseFormatError::from_kind_and_msg(IndexOutOfBounds, "Row index out of bounds."))
Err(SparseFormatError::from_kind_and_msg(
IndexOutOfBounds,
"Row index out of bounds.",
))
} else if !col_indices_in_bounds {
Err(SparseFormatError::from_kind_and_msg(IndexOutOfBounds, "Col index out of bounds."))
Err(SparseFormatError::from_kind_and_msg(
IndexOutOfBounds,
"Col index out of bounds.",
))
} else {
Ok(Self {
nrows,

View File

@ -5,8 +5,8 @@ use num_traits::One;
use nalgebra::Scalar;
use crate::{SparseEntry, SparseEntryMut};
use crate::pattern::SparsityPattern;
use crate::{SparseEntry, SparseEntryMut};
/// An abstract compressed matrix.
///
@ -18,7 +18,7 @@ use crate::pattern::SparsityPattern;
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CsMatrix<T> {
sparsity_pattern: SparsityPattern,
values: Vec<T>
values: Vec<T>,
}
impl<T> CsMatrix<T> {
@ -50,14 +50,22 @@ impl<T> CsMatrix<T> {
#[inline]
pub fn cs_data(&self) -> (&[usize], &[usize], &[T]) {
let pattern = self.pattern();
(pattern.major_offsets(), pattern.minor_indices(), &self.values)
(
pattern.major_offsets(),
pattern.minor_indices(),
&self.values,
)
}
/// Returns the raw data represented as a tuple `(major_offsets, minor_indices, values)`.
#[inline]
pub fn cs_data_mut(&mut self) -> (&[usize], &[usize], &mut [T]) {
let pattern = &mut self.sparsity_pattern;
(pattern.major_offsets(), pattern.minor_indices(), &mut self.values)
(
pattern.major_offsets(),
pattern.minor_indices(),
&mut self.values,
)
}
#[inline]
@ -66,9 +74,12 @@ impl<T> CsMatrix<T> {
}
#[inline]
pub fn from_pattern_and_values(pattern: SparsityPattern, values: Vec<T>)
-> Self {
assert_eq!(pattern.nnz(), values.len(), "Internal error: consumers should verify shape compatibility.");
pub fn from_pattern_and_values(pattern: SparsityPattern, values: Vec<T>) -> Self {
assert_eq!(
pattern.nnz(),
values.len(),
"Internal error: consumers should verify shape compatibility."
);
Self {
sparsity_pattern: pattern,
values,
@ -105,13 +116,21 @@ impl<T> CsMatrix<T> {
let (_, minor_indices, values) = self.cs_data();
let minor_indices = &minor_indices[row_range.clone()];
let values = &values[row_range];
get_entry_from_slices(self.pattern().minor_dim(), minor_indices, values, minor_index)
get_entry_from_slices(
self.pattern().minor_dim(),
minor_indices,
values,
minor_index,
)
}
/// Returns a mutable entry for the given major/minor indices, or `None` if the indices are out
/// of bounds.
pub fn get_entry_mut(&mut self, major_index: usize, minor_index: usize)
-> Option<SparseEntryMut<T>> {
pub fn get_entry_mut(
&mut self,
major_index: usize,
minor_index: usize,
) -> Option<SparseEntryMut<T>> {
let row_range = self.get_index_range(major_index)?;
let minor_dim = self.pattern().minor_dim();
let (_, minor_indices, values) = self.cs_data_mut();
@ -126,7 +145,7 @@ impl<T> CsMatrix<T> {
Some(CsLane {
minor_indices: &minor_indices[range.clone()],
values: &values[range],
minor_dim: self.pattern().minor_dim()
minor_dim: self.pattern().minor_dim(),
})
}
@ -138,7 +157,7 @@ impl<T> CsMatrix<T> {
Some(CsLaneMut {
minor_dim,
minor_indices: &minor_indices[range.clone()],
values: &mut values[range]
values: &mut values[range],
})
}
@ -156,7 +175,7 @@ impl<T> CsMatrix<T> {
pub fn filter<P>(&self, predicate: P) -> Self
where
T: Clone,
P: Fn(usize, usize, &T) -> bool
P: Fn(usize, usize, &T) -> bool,
{
let (major_dim, minor_dim) = (self.pattern().major_dim(), self.pattern().minor_dim());
let mut new_offsets = Vec::with_capacity(self.pattern().major_dim() + 1);
@ -180,7 +199,8 @@ impl<T> CsMatrix<T> {
major_dim,
minor_dim,
new_offsets,
new_indices)
new_indices,
)
.expect("Internal error: Sparsity pattern must always be valid.");
Self::from_pattern_and_values(new_pattern, new_values)
@ -189,7 +209,7 @@ impl<T> CsMatrix<T> {
/// Returns the diagonal of the matrix as a sparse matrix.
pub fn diagonal_as_matrix(&self) -> Self
where
T: Clone
T: Clone,
{
// TODO: This might be faster with a binary search for each diagonal entry
self.filter(|i, j, _| i == j)
@ -204,8 +224,8 @@ impl<T: Scalar + One> CsMatrix<T> {
let values = vec![T::one(); n];
// TODO: We should skip checks here
let pattern = SparsityPattern::try_from_offsets_and_indices(n, n, offsets, indices)
.unwrap();
let pattern =
SparsityPattern::try_from_offsets_and_indices(n, n, offsets, indices).unwrap();
Self::from_pattern_and_values(pattern, values)
}
}
@ -214,7 +234,8 @@ fn get_entry_from_slices<'a, T>(
minor_dim: usize,
minor_indices: &'a [usize],
values: &'a [T],
global_minor_index: usize) -> Option<SparseEntry<'a, T>> {
global_minor_index: usize,
) -> Option<SparseEntry<'a, T>> {
let local_index = minor_indices.binary_search(&global_minor_index);
if let Ok(local_index) = local_index {
Some(SparseEntry::NonZero(&values[local_index]))
@ -229,7 +250,8 @@ fn get_mut_entry_from_slices<'a, T>(
minor_dim: usize,
minor_indices: &'a [usize],
values: &'a mut [T],
global_minor_indices: usize) -> Option<SparseEntryMut<'a, T>> {
global_minor_indices: usize,
) -> Option<SparseEntryMut<'a, T>> {
let local_index = minor_indices.binary_search(&global_minor_indices);
if let Ok(local_index) = local_index {
Some(SparseEntryMut::NonZero(&mut values[local_index]))
@ -244,14 +266,14 @@ fn get_mut_entry_from_slices<'a, T>(
pub struct CsLane<'a, T> {
minor_dim: usize,
minor_indices: &'a [usize],
values: &'a [T]
values: &'a [T],
}
#[derive(Debug, PartialEq, Eq)]
pub struct CsLaneMut<'a, T> {
minor_dim: usize,
minor_indices: &'a [usize],
values: &'a mut [T]
values: &'a mut [T],
}
pub struct CsLaneIter<'a, T> {
@ -266,14 +288,14 @@ impl<'a, T> CsLaneIter<'a, T> {
Self {
current_lane_idx: 0,
pattern,
remaining_values: values
remaining_values: values,
}
}
}
impl<'a, T> Iterator for CsLaneIter<'a, T>
where
T: 'a
T: 'a,
{
type Item = CsLane<'a, T>;
@ -290,7 +312,7 @@ impl<'a, T> Iterator for CsLaneIter<'a, T>
Some(CsLane {
minor_dim,
minor_indices,
values: values_in_lane
values: values_in_lane,
})
} else {
None
@ -310,14 +332,14 @@ impl<'a, T> CsLaneIterMut<'a, T> {
Self {
current_lane_idx: 0,
pattern,
remaining_values: values
remaining_values: values,
}
}
}
impl<'a, T> Iterator for CsLaneIterMut<'a, T>
where
T: 'a
T: 'a,
{
type Item = CsLaneMut<'a, T>;
@ -336,7 +358,7 @@ impl<'a, T> Iterator for CsLaneIterMut<'a, T>
Some(CsLaneMut {
minor_dim,
minor_indices,
values: values_in_lane
values: values_in_lane,
})
} else {
None
@ -375,10 +397,11 @@ macro_rules! impl_cs_lane_common_methods {
self.minor_dim,
self.minor_indices,
self.values,
global_col_index)
}
global_col_index,
)
}
}
};
}
impl_cs_lane_common_methods!(CsLane<'a, T>);
@ -394,10 +417,12 @@ impl<'a, T> CsLaneMut<'a, T> {
}
pub fn get_entry_mut(&mut self, global_minor_index: usize) -> Option<SparseEntryMut<T>> {
get_mut_entry_from_slices(self.minor_dim,
get_mut_entry_from_slices(
self.minor_dim,
self.minor_indices,
self.values,
global_minor_index)
global_minor_index,
)
}
}
@ -405,7 +430,7 @@ impl<'a, T> CsLaneMut<'a, T> {
/// TODO: This doesn't belong here.
struct UninitVec<T> {
vec: Vec<T>,
len: usize
len: usize,
}
impl<T> UninitVec<T> {
@ -414,7 +439,7 @@ impl<T> UninitVec<T> {
vec: Vec::with_capacity(len),
// We need to store len separately, because for zero-sized types,
// Vec::with_capacity(len) does not give vec.capacity() == len
len
len,
}
}
@ -444,10 +469,10 @@ pub fn transpose_cs<T>(
minor_dim: usize,
source_major_offsets: &[usize],
source_minor_indices: &[usize],
values: &[T])
-> (Vec<usize>, Vec<usize>, Vec<T>)
values: &[T],
) -> (Vec<usize>, Vec<usize>, Vec<T>)
where
T: Scalar
T: Scalar,
{
assert_eq!(source_major_offsets.len(), major_dim + 1);
assert_eq!(source_minor_indices.len(), values.len());
@ -481,7 +506,9 @@ where
let target_lane_count = &mut current_target_major_counts[source_minor_idx];
let entry_offset = target_offsets[source_minor_idx] + *target_lane_count;
target_indices[entry_offset] = source_major_idx;
unsafe { target_values.set(entry_offset, val.inlined_clone()); }
unsafe {
target_values.set(entry_offset, val.inlined_clone());
}
*target_lane_count += 1;
}
}

View File

@ -3,14 +3,14 @@
//! This is the module-level documentation. See [`CscMatrix`] for the main documentation of the
//! CSC implementation.
use crate::{SparseFormatError, SparseFormatErrorKind, SparseEntry, SparseEntryMut};
use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter};
use crate::cs::{CsLane, CsLaneIter, CsLaneIterMut, CsLaneMut, CsMatrix};
use crate::csr::CsrMatrix;
use crate::cs::{CsMatrix, CsLane, CsLaneMut, CsLaneIter, CsLaneIterMut};
use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter};
use crate::{SparseEntry, SparseEntryMut, SparseFormatError, SparseFormatErrorKind};
use std::slice::{IterMut, Iter};
use num_traits::{One};
use nalgebra::Scalar;
use num_traits::One;
use std::slice::{Iter, IterMut};
/// A CSC representation of a sparse matrix.
///
@ -130,7 +130,7 @@ impl<T> CscMatrix<T> {
/// Create a zero CSC matrix with no explicitly stored entries.
pub fn zeros(nrows: usize, ncols: usize) -> Self {
Self {
cs: CsMatrix::new(ncols, nrows)
cs: CsMatrix::new(ncols, nrows),
}
}
@ -196,7 +196,11 @@ impl<T> CscMatrix<T> {
values: Vec<T>,
) -> Result<Self, SparseFormatError> {
let pattern = SparsityPattern::try_from_offsets_and_indices(
num_cols, num_rows, col_offsets, row_indices)
num_cols,
num_rows,
col_offsets,
row_indices,
)
.map_err(pattern_format_error_to_csc_error)?;
Self::try_from_pattern_and_values(pattern, values)
}
@ -205,16 +209,19 @@ impl<T> CscMatrix<T> {
///
/// Returns an error if the number of values does not match the number of minor indices
/// in the pattern.
pub fn try_from_pattern_and_values(pattern: SparsityPattern, values: Vec<T>)
-> Result<Self, SparseFormatError> {
pub fn try_from_pattern_and_values(
pattern: SparsityPattern,
values: Vec<T>,
) -> Result<Self, SparseFormatError> {
if pattern.nnz() == values.len() {
Ok(Self {
cs: CsMatrix::from_pattern_and_values(pattern, values)
cs: CsMatrix::from_pattern_and_values(pattern, values),
})
} else {
Err(SparseFormatError::from_kind_and_msg(
SparseFormatErrorKind::InvalidStructure,
"Number of values and row indices must be the same"))
"Number of values and row indices must be the same",
))
}
}
@ -239,7 +246,7 @@ impl<T> CscMatrix<T> {
pub fn triplet_iter(&self) -> CscTripletIter<T> {
CscTripletIter {
pattern_iter: self.pattern().entries(),
values_iter: self.values().iter()
values_iter: self.values().iter(),
}
}
@ -270,7 +277,7 @@ impl<T> CscMatrix<T> {
let (pattern, values) = self.cs.pattern_and_values_mut();
CscTripletIterMut {
pattern_iter: pattern.entries(),
values_mut_iter: values.iter_mut()
values_mut_iter: values.iter_mut(),
}
}
@ -281,8 +288,7 @@ impl<T> CscMatrix<T> {
/// Panics if column index is out of bounds.
#[inline]
pub fn col(&self, index: usize) -> CscCol<T> {
self.get_col(index)
.expect("Row index must be in bounds")
self.get_col(index).expect("Row index must be in bounds")
}
/// Mutable column access for the given column index.
@ -299,23 +305,19 @@ impl<T> CscMatrix<T> {
/// Return the column at the given column index, or `None` if out of bounds.
#[inline]
pub fn get_col(&self, index: usize) -> Option<CscCol<T>> {
self.cs
.get_lane(index)
.map(|lane| CscCol { lane })
self.cs.get_lane(index).map(|lane| CscCol { lane })
}
/// Mutable column access for the given column index, or `None` if out of bounds.
#[inline]
pub fn get_col_mut(&mut self, index: usize) -> Option<CscColMut<T>> {
self.cs
.get_lane_mut(index)
.map(|lane| CscColMut { lane })
self.cs.get_lane_mut(index).map(|lane| CscColMut { lane })
}
/// An iterator over columns in the matrix.
pub fn col_iter(&self) -> CscColIter<T> {
CscColIter {
lane_iter: CsLaneIter::new(self.pattern(), self.values())
lane_iter: CsLaneIter::new(self.pattern(), self.values()),
}
}
@ -323,7 +325,7 @@ impl<T> CscMatrix<T> {
pub fn col_iter_mut(&mut self) -> CscColIterMut<T> {
let (pattern, values) = self.cs.pattern_and_values_mut();
CscColIterMut {
lane_iter: CsLaneIterMut::new(pattern, values)
lane_iter: CsLaneIterMut::new(pattern, values),
}
}
@ -397,8 +399,11 @@ impl<T> CscMatrix<T> {
///
/// Each call to this function incurs the cost of a binary search among the explicitly
/// stored row entries for the given column.
pub fn get_entry_mut(&mut self, row_index: usize, col_index: usize)
-> Option<SparseEntryMut<T>> {
pub fn get_entry_mut(
&mut self,
row_index: usize,
col_index: usize,
) -> Option<SparseEntryMut<T>> {
self.cs.get_entry_mut(col_index, row_index)
}
@ -444,11 +449,15 @@ impl<T> CscMatrix<T> {
pub fn filter<P>(&self, predicate: P) -> Self
where
T: Clone,
P: Fn(usize, usize, &T) -> bool
P: Fn(usize, usize, &T) -> bool,
{
// Note: Predicate uses (row, col, value), so we have to switch around since
// cs uses (major, minor, value)
Self { cs: self.cs.filter(|col_idx, row_idx, v| predicate(row_idx, col_idx, v)) }
Self {
cs: self
.cs
.filter(|col_idx, row_idx, v| predicate(row_idx, col_idx, v)),
}
}
/// Returns a new matrix representing the upper triangular part of this matrix.
@ -456,7 +465,7 @@ impl<T> CscMatrix<T> {
/// The result includes the diagonal of the matrix.
pub fn upper_triangle(&self) -> Self
where
T: Clone
T: Clone,
{
self.filter(|i, j, _| i <= j)
}
@ -466,7 +475,7 @@ impl<T> CscMatrix<T> {
/// The result includes the diagonal of the matrix.
pub fn lower_triangle(&self) -> Self
where
T: Clone
T: Clone,
{
self.filter(|i, j, _| i >= j)
}
@ -474,15 +483,17 @@ impl<T> CscMatrix<T> {
/// Returns the diagonal of the matrix as a sparse matrix.
pub fn diagonal_as_csc(&self) -> Self
where
T: Clone
T: Clone,
{
Self { cs: self.cs.diagonal_as_matrix() }
Self {
cs: self.cs.diagonal_as_matrix(),
}
}
}
impl<T> CscMatrix<T>
where
T: Scalar
T: Scalar,
{
/// Compute the transpose of the matrix.
pub fn transpose(&self) -> CscMatrix<T> {
@ -495,7 +506,7 @@ impl<T: Scalar + One> CscMatrix<T> {
#[inline]
pub fn identity(n: usize) -> Self {
Self {
cs: CsMatrix::identity(n)
cs: CsMatrix::identity(n),
}
}
}
@ -505,30 +516,34 @@ impl<T: Scalar + One> CscMatrix<T> {
/// This ensures that the terminology is consistent: we are talking about rows and columns,
/// not lanes, major and minor dimensions.
fn pattern_format_error_to_csc_error(err: SparsityPatternFormatError) -> SparseFormatError {
use SparsityPatternFormatError::*;
use SparsityPatternFormatError::DuplicateEntry as PatternDuplicateEntry;
use SparseFormatError as E;
use SparseFormatErrorKind as K;
use SparsityPatternFormatError::DuplicateEntry as PatternDuplicateEntry;
use SparsityPatternFormatError::*;
match err {
InvalidOffsetArrayLength => E::from_kind_and_msg(
K::InvalidStructure,
"Length of col offset array is not equal to ncols + 1."),
"Length of col offset array is not equal to ncols + 1.",
),
InvalidOffsetFirstLast => E::from_kind_and_msg(
K::InvalidStructure,
"First or last col offset is inconsistent with format specification."),
"First or last col offset is inconsistent with format specification.",
),
NonmonotonicOffsets => E::from_kind_and_msg(
K::InvalidStructure,
"Col offsets are not monotonically increasing."),
"Col offsets are not monotonically increasing.",
),
NonmonotonicMinorIndices => E::from_kind_and_msg(
K::InvalidStructure,
"Row indices are not monotonically increasing (sorted) within each column."),
MinorIndexOutOfBounds => E::from_kind_and_msg(
K::IndexOutOfBounds,
"Row indices are out of bounds."),
PatternDuplicateEntry => E::from_kind_and_msg(
K::DuplicateEntry,
"Matrix data contains duplicate entries."),
"Row indices are not monotonically increasing (sorted) within each column.",
),
MinorIndexOutOfBounds => {
E::from_kind_and_msg(K::IndexOutOfBounds, "Row indices are out of bounds.")
}
PatternDuplicateEntry => {
E::from_kind_and_msg(K::DuplicateEntry, "Matrix data contains duplicate entries.")
}
}
}
@ -536,7 +551,7 @@ fn pattern_format_error_to_csc_error(err: SparsityPatternFormatError) -> SparseF
#[derive(Debug)]
pub struct CscTripletIter<'a, T> {
pattern_iter: SparsityPatternIter<'a>,
values_iter: Iter<'a, T>
values_iter: Iter<'a, T>,
}
impl<'a, T: Clone> CscTripletIter<'a, T> {
@ -559,7 +574,7 @@ impl<'a, T> Iterator for CscTripletIter<'a, T> {
match (next_entry, next_value) {
(Some((i, j)), Some(v)) => Some((j, i, v)),
_ => None
_ => None,
}
}
}
@ -568,7 +583,7 @@ impl<'a, T> Iterator for CscTripletIter<'a, T> {
#[derive(Debug)]
pub struct CscTripletIterMut<'a, T> {
pattern_iter: SparsityPatternIter<'a>,
values_mut_iter: IterMut<'a, T>
values_mut_iter: IterMut<'a, T>,
}
impl<'a, T> Iterator for CscTripletIterMut<'a, T> {
@ -581,7 +596,7 @@ impl<'a, T> Iterator for CscTripletIterMut<'a, T> {
match (next_entry, next_value) {
(Some((i, j)), Some(v)) => Some((j, i, v)),
_ => None
_ => None,
}
}
}
@ -589,7 +604,7 @@ impl<'a, T> Iterator for CscTripletIterMut<'a, T> {
/// An immutable representation of a column in a CSC matrix.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CscCol<'a, T> {
lane: CsLane<'a, T>
lane: CsLane<'a, T>,
}
/// A mutable representation of a column in a CSC matrix.
@ -598,7 +613,7 @@ pub struct CscCol<'a, T> {
/// to the column cannot be modified.
#[derive(Debug, PartialEq, Eq)]
pub struct CscColMut<'a, T> {
lane: CsLaneMut<'a, T>
lane: CsLaneMut<'a, T>,
}
/// Implement the methods common to both CscCol and CscColMut
@ -637,7 +652,7 @@ macro_rules! impl_csc_col_common_methods {
self.lane.get_entry(global_row_index)
}
}
}
};
}
impl_csc_col_common_methods!(CscCol<'a, T>);
@ -666,33 +681,29 @@ impl<'a, T> CscColMut<'a, T> {
/// Column iterator for [CscMatrix](struct.CscMatrix.html).
pub struct CscColIter<'a, T> {
lane_iter: CsLaneIter<'a, T>
lane_iter: CsLaneIter<'a, T>,
}
impl<'a, T> Iterator for CscColIter<'a, T> {
type Item = CscCol<'a, T>;
fn next(&mut self) -> Option<Self::Item> {
self.lane_iter
.next()
.map(|lane| CscCol { lane })
self.lane_iter.next().map(|lane| CscCol { lane })
}
}
/// Mutable column iterator for [CscMatrix](struct.CscMatrix.html).
pub struct CscColIterMut<'a, T> {
lane_iter: CsLaneIterMut<'a, T>
lane_iter: CsLaneIterMut<'a, T>,
}
impl<'a, T> Iterator for CscColIterMut<'a, T>
where
T: 'a
T: 'a,
{
type Item = CscColMut<'a, T>;
fn next(&mut self) -> Option<Self::Item> {
self.lane_iter
.next()
.map(|lane| CscColMut { lane })
self.lane_iter.next().map(|lane| CscColMut { lane })
}
}

View File

@ -2,15 +2,15 @@
//!
//! This is the module-level documentation. See [`CsrMatrix`] for the main documentation of the
//! CSC implementation.
use crate::{SparseFormatError, SparseFormatErrorKind, SparseEntry, SparseEntryMut};
use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter};
use crate::cs::{CsLane, CsLaneIter, CsLaneIterMut, CsLaneMut, CsMatrix};
use crate::csc::CscMatrix;
use crate::cs::{CsMatrix, CsLaneIterMut, CsLaneIter, CsLane, CsLaneMut};
use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter};
use crate::{SparseEntry, SparseEntryMut, SparseFormatError, SparseFormatErrorKind};
use nalgebra::Scalar;
use num_traits::{One};
use num_traits::One;
use std::slice::{IterMut, Iter};
use std::slice::{Iter, IterMut};
/// A CSR representation of a sparse matrix.
///
@ -130,7 +130,7 @@ impl<T> CsrMatrix<T> {
/// Create a zero CSR matrix with no explicitly stored entries.
pub fn zeros(nrows: usize, ncols: usize) -> Self {
Self {
cs: CsMatrix::new(nrows, ncols)
cs: CsMatrix::new(nrows, ncols),
}
}
@ -198,7 +198,11 @@ impl<T> CsrMatrix<T> {
values: Vec<T>,
) -> Result<Self, SparseFormatError> {
let pattern = SparsityPattern::try_from_offsets_and_indices(
num_rows, num_cols, row_offsets, col_indices)
num_rows,
num_cols,
row_offsets,
col_indices,
)
.map_err(pattern_format_error_to_csr_error)?;
Self::try_from_pattern_and_values(pattern, values)
}
@ -207,16 +211,19 @@ impl<T> CsrMatrix<T> {
///
/// Returns an error if the number of values does not match the number of minor indices
/// in the pattern.
pub fn try_from_pattern_and_values(pattern: SparsityPattern, values: Vec<T>)
-> Result<Self, SparseFormatError> {
pub fn try_from_pattern_and_values(
pattern: SparsityPattern,
values: Vec<T>,
) -> Result<Self, SparseFormatError> {
if pattern.nnz() == values.len() {
Ok(Self {
cs: CsMatrix::from_pattern_and_values(pattern, values)
cs: CsMatrix::from_pattern_and_values(pattern, values),
})
} else {
Err(SparseFormatError::from_kind_and_msg(
SparseFormatErrorKind::InvalidStructure,
"Number of values and column indices must be the same"))
"Number of values and column indices must be the same",
))
}
}
@ -241,7 +248,7 @@ impl<T> CsrMatrix<T> {
pub fn triplet_iter(&self) -> CsrTripletIter<T> {
CsrTripletIter {
pattern_iter: self.pattern().entries(),
values_iter: self.values().iter()
values_iter: self.values().iter(),
}
}
@ -272,7 +279,7 @@ impl<T> CsrMatrix<T> {
let (pattern, values) = self.cs.pattern_and_values_mut();
CsrTripletIterMut {
pattern_iter: pattern.entries(),
values_mut_iter: values.iter_mut()
values_mut_iter: values.iter_mut(),
}
}
@ -283,8 +290,7 @@ impl<T> CsrMatrix<T> {
/// Panics if row index is out of bounds.
#[inline]
pub fn row(&self, index: usize) -> CsrRow<T> {
self.get_row(index)
.expect("Row index must be in bounds")
self.get_row(index).expect("Row index must be in bounds")
}
/// Mutable row access for the given row index.
@ -301,23 +307,19 @@ impl<T> CsrMatrix<T> {
/// Return the row at the given row index, or `None` if out of bounds.
#[inline]
pub fn get_row(&self, index: usize) -> Option<CsrRow<T>> {
self.cs
.get_lane(index)
.map(|lane| CsrRow { lane })
self.cs.get_lane(index).map(|lane| CsrRow { lane })
}
/// Mutable row access for the given row index, or `None` if out of bounds.
#[inline]
pub fn get_row_mut(&mut self, index: usize) -> Option<CsrRowMut<T>> {
self.cs
.get_lane_mut(index)
.map(|lane| CsrRowMut { lane })
self.cs.get_lane_mut(index).map(|lane| CsrRowMut { lane })
}
/// An iterator over rows in the matrix.
pub fn row_iter(&self) -> CsrRowIter<T> {
CsrRowIter {
lane_iter: CsLaneIter::new(self.pattern(), self.values())
lane_iter: CsLaneIter::new(self.pattern(), self.values()),
}
}
@ -399,8 +401,11 @@ impl<T> CsrMatrix<T> {
///
/// Each call to this function incurs the cost of a binary search among the explicitly
/// stored column entries for the given row.
pub fn get_entry_mut(&mut self, row_index: usize, col_index: usize)
-> Option<SparseEntryMut<T>> {
pub fn get_entry_mut(
&mut self,
row_index: usize,
col_index: usize,
) -> Option<SparseEntryMut<T>> {
self.cs.get_entry_mut(row_index, col_index)
}
@ -446,9 +451,13 @@ impl<T> CsrMatrix<T> {
pub fn filter<P>(&self, predicate: P) -> Self
where
T: Clone,
P: Fn(usize, usize, &T) -> bool
P: Fn(usize, usize, &T) -> bool,
{
Self { cs: self.cs.filter(|row_idx, col_idx, v| predicate(row_idx, col_idx, v)) }
Self {
cs: self
.cs
.filter(|row_idx, col_idx, v| predicate(row_idx, col_idx, v)),
}
}
/// Returns a new matrix representing the upper triangular part of this matrix.
@ -456,7 +465,7 @@ impl<T> CsrMatrix<T> {
/// The result includes the diagonal of the matrix.
pub fn upper_triangle(&self) -> Self
where
T: Clone
T: Clone,
{
self.filter(|i, j, _| i <= j)
}
@ -466,7 +475,7 @@ impl<T> CsrMatrix<T> {
/// The result includes the diagonal of the matrix.
pub fn lower_triangle(&self) -> Self
where
T: Clone
T: Clone,
{
self.filter(|i, j, _| i >= j)
}
@ -474,15 +483,17 @@ impl<T> CsrMatrix<T> {
/// Returns the diagonal of the matrix as a sparse matrix.
pub fn diagonal_as_csr(&self) -> Self
where
T: Clone
T: Clone,
{
Self { cs: self.cs.diagonal_as_matrix() }
Self {
cs: self.cs.diagonal_as_matrix(),
}
}
}
impl<T> CsrMatrix<T>
where
T: Scalar
T: Scalar,
{
/// Compute the transpose of the matrix.
pub fn transpose(&self) -> CsrMatrix<T> {
@ -495,7 +506,7 @@ impl<T: Scalar + One> CsrMatrix<T> {
#[inline]
pub fn identity(n: usize) -> Self {
Self {
cs: CsMatrix::identity(n)
cs: CsMatrix::identity(n),
}
}
}
@ -505,30 +516,34 @@ impl<T: Scalar + One> CsrMatrix<T> {
/// This ensures that the terminology is consistent: we are talking about rows and columns,
/// not lanes, major and minor dimensions.
fn pattern_format_error_to_csr_error(err: SparsityPatternFormatError) -> SparseFormatError {
use SparsityPatternFormatError::*;
use SparsityPatternFormatError::DuplicateEntry as PatternDuplicateEntry;
use SparseFormatError as E;
use SparseFormatErrorKind as K;
use SparsityPatternFormatError::DuplicateEntry as PatternDuplicateEntry;
use SparsityPatternFormatError::*;
match err {
InvalidOffsetArrayLength => E::from_kind_and_msg(
K::InvalidStructure,
"Length of row offset array is not equal to nrows + 1."),
"Length of row offset array is not equal to nrows + 1.",
),
InvalidOffsetFirstLast => E::from_kind_and_msg(
K::InvalidStructure,
"First or last row offset is inconsistent with format specification."),
"First or last row offset is inconsistent with format specification.",
),
NonmonotonicOffsets => E::from_kind_and_msg(
K::InvalidStructure,
"Row offsets are not monotonically increasing."),
"Row offsets are not monotonically increasing.",
),
NonmonotonicMinorIndices => E::from_kind_and_msg(
K::InvalidStructure,
"Column indices are not monotonically increasing (sorted) within each row."),
MinorIndexOutOfBounds => E::from_kind_and_msg(
K::IndexOutOfBounds,
"Column indices are out of bounds."),
PatternDuplicateEntry => E::from_kind_and_msg(
K::DuplicateEntry,
"Matrix data contains duplicate entries."),
"Column indices are not monotonically increasing (sorted) within each row.",
),
MinorIndexOutOfBounds => {
E::from_kind_and_msg(K::IndexOutOfBounds, "Column indices are out of bounds.")
}
PatternDuplicateEntry => {
E::from_kind_and_msg(K::DuplicateEntry, "Matrix data contains duplicate entries.")
}
}
}
@ -536,7 +551,7 @@ fn pattern_format_error_to_csr_error(err: SparsityPatternFormatError) -> SparseF
#[derive(Debug)]
pub struct CsrTripletIter<'a, T> {
pattern_iter: SparsityPatternIter<'a>,
values_iter: Iter<'a, T>
values_iter: Iter<'a, T>,
}
impl<'a, T: Clone> CsrTripletIter<'a, T> {
@ -559,7 +574,7 @@ impl<'a, T> Iterator for CsrTripletIter<'a, T> {
match (next_entry, next_value) {
(Some((i, j)), Some(v)) => Some((i, j, v)),
_ => None
_ => None,
}
}
}
@ -568,7 +583,7 @@ impl<'a, T> Iterator for CsrTripletIter<'a, T> {
#[derive(Debug)]
pub struct CsrTripletIterMut<'a, T> {
pattern_iter: SparsityPatternIter<'a>,
values_mut_iter: IterMut<'a, T>
values_mut_iter: IterMut<'a, T>,
}
impl<'a, T> Iterator for CsrTripletIterMut<'a, T> {
@ -581,7 +596,7 @@ impl<'a, T> Iterator for CsrTripletIterMut<'a, T> {
match (next_entry, next_value) {
(Some((i, j)), Some(v)) => Some((i, j, v)),
_ => None
_ => None,
}
}
}
@ -589,7 +604,7 @@ impl<'a, T> Iterator for CsrTripletIterMut<'a, T> {
/// An immutable representation of a row in a CSR matrix.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CsrRow<'a, T> {
lane: CsLane<'a, T>
lane: CsLane<'a, T>,
}
/// A mutable representation of a row in a CSR matrix.
@ -598,7 +613,7 @@ pub struct CsrRow<'a, T> {
/// to the row cannot be modified.
#[derive(Debug, PartialEq, Eq)]
pub struct CsrRowMut<'a, T> {
lane: CsLaneMut<'a, T>
lane: CsLaneMut<'a, T>,
}
/// Implement the methods common to both CsrRow and CsrRowMut
@ -638,7 +653,7 @@ macro_rules! impl_csr_row_common_methods {
self.lane.get_entry(global_col_index)
}
}
}
};
}
impl_csr_row_common_methods!(CsrRow<'a, T>);
@ -670,33 +685,29 @@ impl<'a, T> CsrRowMut<'a, T> {
/// Row iterator for [CsrMatrix](struct.CsrMatrix.html).
pub struct CsrRowIter<'a, T> {
lane_iter: CsLaneIter<'a, T>
lane_iter: CsLaneIter<'a, T>,
}
impl<'a, T> Iterator for CsrRowIter<'a, T> {
type Item = CsrRow<'a, T>;
fn next(&mut self) -> Option<Self::Item> {
self.lane_iter
.next()
.map(|lane| CsrRow { lane })
self.lane_iter.next().map(|lane| CsrRow { lane })
}
}
/// Mutable row iterator for [CsrMatrix](struct.CsrMatrix.html).
pub struct CsrRowIterMut<'a, T> {
lane_iter: CsLaneIterMut<'a, T>
lane_iter: CsLaneIterMut<'a, T>,
}
impl<'a, T> Iterator for CsrRowIterMut<'a, T>
where
T: 'a
T: 'a,
{
type Item = CsrRowMut<'a, T>;
fn next(&mut self) -> Option<Self::Item> {
self.lane_iter
.next()
.map(|lane| CsrRowMut { lane })
self.lane_iter.next().map(|lane| CsrRowMut { lane })
}
}

View File

@ -1,10 +1,10 @@
use crate::pattern::SparsityPattern;
use crate::csc::CscMatrix;
use core::{mem, iter};
use nalgebra::{Scalar, RealField, DMatrixSlice, DMatrixSliceMut, DMatrix};
use std::fmt::{Display, Formatter};
use crate::ops::serial::spsolve_csc_lower_triangular;
use crate::ops::Op;
use crate::pattern::SparsityPattern;
use core::{iter, mem};
use nalgebra::{DMatrix, DMatrixSlice, DMatrixSliceMut, RealField, Scalar};
use std::fmt::{Display, Formatter};
/// A symbolic sparse Cholesky factorization of a CSC matrix.
///
@ -15,7 +15,7 @@ pub struct CscSymbolicCholesky {
m_pattern: SparsityPattern,
l_pattern: SparsityPattern,
// u in this context is L^T, so that M = L L^T
u_pattern: SparsityPattern
u_pattern: SparsityPattern,
}
impl CscSymbolicCholesky {
@ -28,8 +28,11 @@ impl CscSymbolicCholesky {
///
/// Panics if the sparsity pattern is not square.
pub fn factor(pattern: SparsityPattern) -> Self {
assert_eq!(pattern.major_dim(), pattern.minor_dim(),
"Major and minor dimensions must be the same (square matrix).");
assert_eq!(
pattern.major_dim(),
pattern.minor_dim(),
"Major and minor dimensions must be the same (square matrix)."
);
let (l_pattern, u_pattern) = nonzero_pattern(&pattern);
Self {
m_pattern: pattern,
@ -65,7 +68,7 @@ pub struct CscCholesky<T> {
l_factor: CscMatrix<T>,
u_pattern: SparsityPattern,
work_x: Vec<T>,
work_c: Vec<usize>
work_c: Vec<usize>,
}
#[derive(Debug, PartialEq, Eq, Clone)]
@ -100,16 +103,20 @@ impl<T: RealField> CscCholesky<T> {
///
/// Panics if the number of values differ from the number of non-zeros of the sparsity pattern
/// of the matrix that was symbolically factored.
pub fn factor_numerical(symbolic: CscSymbolicCholesky, values: &[T])
-> Result<Self, CholeskyError>
{
assert_eq!(symbolic.l_pattern.nnz(), symbolic.u_pattern.nnz(),
"u is just the transpose of l, so should have the same nnz");
pub fn factor_numerical(
symbolic: CscSymbolicCholesky,
values: &[T],
) -> Result<Self, CholeskyError> {
assert_eq!(
symbolic.l_pattern.nnz(),
symbolic.u_pattern.nnz(),
"u is just the transpose of l, so should have the same nnz"
);
let l_nnz = symbolic.l_pattern.nnz();
let l_values = vec![T::zero(); l_nnz];
let l_factor = CscMatrix::try_from_pattern_and_values(symbolic.l_pattern, l_values)
.unwrap();
let l_factor =
CscMatrix::try_from_pattern_and_values(symbolic.l_pattern, l_values).unwrap();
let (nrows, ncols) = (l_factor.nrows(), l_factor.ncols());
@ -229,11 +236,9 @@ impl<T: RealField> CscCholesky<T> {
{
let (offsets, _, values) = self.l_factor.csc_data_mut();
*values
.get_unchecked_mut(*offsets.get_unchecked(k)) = denom;
*values.get_unchecked_mut(*offsets.get_unchecked(k)) = denom;
}
let mut col_k = self.l_factor.col_mut(k);
let (col_k_rows, col_k_values) = col_k.rows_and_values_mut();
let col_k_entries = col_k_rows.iter().zip(col_k_values);
@ -269,19 +274,16 @@ impl<T: RealField> CscCholesky<T> {
/// # Panics
///
/// Panics if `b` is not square.
pub fn solve_mut<'a>(&'a self, b: impl Into<DMatrixSliceMut<'a, T>>)
{
pub fn solve_mut<'a>(&'a self, b: impl Into<DMatrixSliceMut<'a, T>>) {
let expect_msg = "If the Cholesky factorization succeeded,\
then the triangular solve should never fail";
// Solve LY = B
let mut y = b.into();
spsolve_csc_lower_triangular(Op::NoOp(self.l()), &mut y)
.expect(expect_msg);
spsolve_csc_lower_triangular(Op::NoOp(self.l()), &mut y).expect(expect_msg);
// Solve L^T X = Y
let mut x = y;
spsolve_csc_lower_triangular(Op::Transpose(self.l()), &mut x)
.expect(expect_msg);
spsolve_csc_lower_triangular(Op::Transpose(self.l()), &mut x).expect(expect_msg);
}
}
@ -333,8 +335,8 @@ fn nonzero_pattern(m: &SparsityPattern) -> (SparsityPattern, SparsityPattern) {
col_offsets.push(rows.len());
}
let u_pattern = SparsityPattern::try_from_offsets_and_indices(nrows, ncols, col_offsets, rows)
.unwrap();
let u_pattern =
SparsityPattern::try_from_offsets_and_indices(nrows, ncols, col_offsets, rows).unwrap();
// TODO: Avoid this transpose?
let l_pattern = u_pattern.transpose();

View File

@ -135,13 +135,13 @@
#![deny(unused_results)]
#![deny(missing_docs)]
pub mod convert;
pub mod coo;
pub mod csc;
pub mod csr;
pub mod pattern;
pub mod ops;
pub mod convert;
pub mod factorization;
pub mod ops;
pub mod pattern;
pub(crate) mod cs;
@ -151,16 +151,16 @@ pub mod proptest;
#[cfg(feature = "compare")]
mod matrixcompare;
use num_traits::Zero;
use std::error::Error;
use std::fmt;
use num_traits::Zero;
/// Errors produced by functions that expect well-formed sparse format data.
#[derive(Debug)]
pub struct SparseFormatError {
kind: SparseFormatErrorKind,
// Currently we only use an underlying error for generating the `Display` impl
error: Box<dyn Error>
error: Box<dyn Error>,
}
impl SparseFormatError {
@ -170,10 +170,7 @@ impl SparseFormatError {
}
pub(crate) fn from_kind_and_error(kind: SparseFormatErrorKind, error: Box<dyn Error>) -> Self {
Self {
kind,
error
}
Self { kind, error }
}
/// Helper functionality for more conveniently creating errors.
@ -221,7 +218,7 @@ pub enum SparseEntry<'a, T> {
/// is explicitly stored (a so-called "explicit zero").
NonZero(&'a T),
/// The entry is implicitly zero, i.e. it is not explicitly stored.
Zero
Zero,
}
impl<'a, T: Clone + Zero> SparseEntry<'a, T> {
@ -232,7 +229,7 @@ impl<'a, T: Clone + Zero> SparseEntry<'a, T> {
pub fn to_value(self) -> T {
match self {
SparseEntry::NonZero(value) => value.clone(),
SparseEntry::Zero => T::zero()
SparseEntry::Zero => T::zero(),
}
}
}
@ -248,7 +245,7 @@ pub enum SparseEntryMut<'a, T> {
/// is explicitly stored (a so-called "explicit zero").
NonZero(&'a mut T),
/// The entry is implicitly zero i.e. it is not explicitly stored.
Zero
Zero,
}
impl<'a, T: Clone + Zero> SparseEntryMut<'a, T> {
@ -259,7 +256,7 @@ impl<'a, T: Clone + Zero> SparseEntryMut<'a, T> {
pub fn to_value(self) -> T {
match self {
SparseEntryMut::NonZero(value) => value.clone(),
SparseEntryMut::Zero => T::zero()
SparseEntryMut::Zero => T::zero(),
}
}
}

View File

@ -1,9 +1,9 @@
//! Implements core traits for use with `matrixcompare`.
use crate::csr::CsrMatrix;
use crate::coo::CooMatrix;
use crate::csc::CscMatrix;
use crate::csr::CsrMatrix;
use matrixcompare_core;
use matrixcompare_core::{Access, SparseAccess};
use crate::coo::CooMatrix;
macro_rules! impl_matrix_for_csr_csc {
($MatrixType:ident) => {
@ -13,7 +13,9 @@ macro_rules! impl_matrix_for_csr_csc {
}
fn fetch_triplets(&self) -> Vec<(usize, usize, T)> {
self.triplet_iter().map(|(i, j, v)| (i, j, v.clone())).collect()
self.triplet_iter()
.map(|(i, j, v)| (i, j, v.clone()))
.collect()
}
}
@ -30,7 +32,7 @@ macro_rules! impl_matrix_for_csr_csc {
Access::Sparse(self)
}
}
}
};
}
impl_matrix_for_csr_csc!(CsrMatrix);
@ -42,7 +44,9 @@ impl<T: Clone> SparseAccess<T> for CooMatrix<T> {
}
fn fetch_triplets(&self) -> Vec<(usize, usize, T)> {
self.triplet_iter().map(|(i, j, v)| (i, j, v.clone())).collect()
self.triplet_iter()
.map(|(i, j, v)| (i, j, v.clone()))
.collect()
}
}

View File

@ -1,15 +1,20 @@
use crate::csr::CsrMatrix;
use crate::csc::CscMatrix;
use crate::csr::CsrMatrix;
use std::ops::{Add, Div, DivAssign, Mul, MulAssign, Sub, Neg};
use crate::ops::serial::{spadd_csr_prealloc, spadd_csc_prealloc, spadd_pattern, spmm_csr_pattern, spmm_csr_prealloc, spmm_csc_prealloc, spmm_csc_dense, spmm_csr_dense, spmm_csc_pattern};
use nalgebra::{ClosedAdd, ClosedMul, ClosedSub, ClosedDiv, Scalar, Matrix, MatrixMN, Dim,
Dynamic, DefaultAllocator, U1};
use nalgebra::allocator::{Allocator};
use nalgebra::constraint::{DimEq, ShapeConstraint};
use num_traits::{Zero, One};
use crate::ops::{Op};
use crate::ops::serial::{
spadd_csc_prealloc, spadd_csr_prealloc, spadd_pattern, spmm_csc_dense, spmm_csc_pattern,
spmm_csc_prealloc, spmm_csr_dense, spmm_csr_pattern, spmm_csr_prealloc,
};
use crate::ops::Op;
use nalgebra::allocator::Allocator;
use nalgebra::base::storage::Storage;
use nalgebra::constraint::{DimEq, ShapeConstraint};
use nalgebra::{
ClosedAdd, ClosedDiv, ClosedMul, ClosedSub, DefaultAllocator, Dim, Dynamic, Matrix, MatrixMN,
Scalar, U1,
};
use num_traits::{One, Zero};
use std::ops::{Add, Div, DivAssign, Mul, MulAssign, Neg, Sub};
/// Helper macro for implementing binary operators for different matrix types
/// See below for usage.
@ -188,7 +193,7 @@ macro_rules! impl_neg {
($matrix_type:ident) => {
impl<T> Neg for $matrix_type<T>
where
T: Scalar + Neg<Output=T>
T: Scalar + Neg<Output = T>,
{
type Output = $matrix_type<T>;
@ -202,7 +207,7 @@ macro_rules! impl_neg {
impl<'a, T> Neg for &'a $matrix_type<T>
where
T: Scalar + Neg<Output=T>
T: Scalar + Neg<Output = T>,
{
type Output = $matrix_type<T>;
@ -214,7 +219,7 @@ macro_rules! impl_neg {
-self.clone()
}
}
}
};
}
impl_neg!(CsrMatrix);

View File

@ -148,13 +148,14 @@ impl<T> Op<T> {
pub fn as_ref(&self) -> Op<&T> {
match self {
Op::NoOp(obj) => Op::NoOp(&obj),
Op::Transpose(obj) => Op::Transpose(&obj)
Op::Transpose(obj) => Op::Transpose(&obj),
}
}
/// Converts the underlying data type.
pub fn convert<U>(self) -> Op<U>
where T: Into<U>
where
T: Into<U>,
{
self.map_same_op(T::into)
}
@ -163,7 +164,7 @@ impl<T> Op<T> {
pub fn map_same_op<U, F: FnOnce(T) -> U>(self, f: F) -> Op<U> {
match self {
Op::NoOp(obj) => Op::NoOp(f(obj)),
Op::Transpose(obj) => Op::Transpose(f(obj))
Op::Transpose(obj) => Op::Transpose(f(obj)),
}
}
@ -181,7 +182,7 @@ impl<T> Op<T> {
pub fn transposed(self) -> Self {
match self {
Op::NoOp(obj) => Op::Transpose(obj),
Op::Transpose(obj) => Op::NoOp(obj)
Op::Transpose(obj) => Op::NoOp(obj),
}
}
}
@ -191,4 +192,3 @@ impl<T> From<T> for Op<T> {
Self::NoOp(obj)
}
}

View File

@ -1,14 +1,15 @@
use crate::cs::CsMatrix;
use crate::ops::serial::{OperationError, OperationErrorKind};
use crate::ops::Op;
use crate::ops::serial::{OperationErrorKind, OperationError};
use nalgebra::{Scalar, ClosedAdd, ClosedMul, DMatrixSliceMut, DMatrixSlice};
use num_traits::{Zero, One};
use crate::SparseEntryMut;
use nalgebra::{ClosedAdd, ClosedMul, DMatrixSlice, DMatrixSliceMut, Scalar};
use num_traits::{One, Zero};
fn spmm_cs_unexpected_entry() -> OperationError {
OperationError::from_kind_and_message(
OperationErrorKind::InvalidPattern,
String::from("Found unexpected entry that is not present in `c`."))
String::from("Found unexpected entry that is not present in `c`."),
)
}
/// Helper functionality for implementing CSR/CSC SPMM.
@ -24,10 +25,10 @@ pub fn spmm_cs_prealloc<T>(
c: &mut CsMatrix<T>,
alpha: T,
a: &CsMatrix<T>,
b: &CsMatrix<T>)
-> Result<(), OperationError>
b: &CsMatrix<T>,
) -> Result<(), OperationError>
where
T: Scalar + ClosedAdd + ClosedMul + Zero + One
T: Scalar + ClosedAdd + ClosedMul + Zero + One,
{
for i in 0..c.pattern().major_dim() {
let a_lane_i = a.get_lane(i).unwrap();
@ -42,7 +43,8 @@ pub fn spmm_cs_prealloc<T>(
let alpha_aik = alpha.inlined_clone() * a_ik.inlined_clone();
for (j, b_kj) in b_lane_k.minor_indices().iter().zip(b_lane_k.values()) {
// Determine the location in C to append the value
let (c_local_idx, _) = c_lane_i_cols.iter()
let (c_local_idx, _) = c_lane_i_cols
.iter()
.enumerate()
.find(|(_, c_col)| *c_col == j)
.ok_or_else(spmm_cs_unexpected_entry)?;
@ -60,17 +62,19 @@ pub fn spmm_cs_prealloc<T>(
fn spadd_cs_unexpected_entry() -> OperationError {
OperationError::from_kind_and_message(
OperationErrorKind::InvalidPattern,
String::from("Found entry in `op(a)` that is not present in `c`."))
String::from("Found entry in `op(a)` that is not present in `c`."),
)
}
/// Helper functionality for implementing CSR/CSC SPADD.
pub fn spadd_cs_prealloc<T>(beta: T,
pub fn spadd_cs_prealloc<T>(
beta: T,
c: &mut CsMatrix<T>,
alpha: T,
a: Op<&CsMatrix<T>>)
-> Result<(), OperationError>
a: Op<&CsMatrix<T>>,
) -> Result<(), OperationError>
where
T: Scalar + ClosedAdd + ClosedMul + Zero + One
T: Scalar + ClosedAdd + ClosedMul + Zero + One,
{
match a {
Op::NoOp(a) => {
@ -88,7 +92,8 @@ pub fn spadd_cs_prealloc<T>(beta: T,
// TODO: Use exponential search instead of linear search.
// If C has substantially more entries in the row than A, then a line search
// will needlessly visit many entries in C.
let (c_idx, _) = c_minors.iter()
let (c_idx, _) = c_minors
.iter()
.enumerate()
.find(|(_, c_col)| *c_col == a_col)
.ok_or_else(spadd_cs_unexpected_entry)?;
@ -110,7 +115,7 @@ pub fn spadd_cs_prealloc<T>(beta: T,
let a_val = a_val.inlined_clone();
let alpha = alpha.inlined_clone();
match c.get_entry_mut(j, i).unwrap() {
SparseEntryMut::NonZero(c_ji) => { *c_ji += alpha * a_val }
SparseEntryMut::NonZero(c_ji) => *c_ji += alpha * a_val,
SparseEntryMut::Zero => return Err(spadd_cs_unexpected_entry()),
}
}
@ -124,13 +129,14 @@ pub fn spadd_cs_prealloc<T>(beta: T,
///
/// The implementation essentially assumes that `a` is a CSR matrix. To use it with CSC matrices,
/// the transposed operation must be specified for the CSC matrix.
pub fn spmm_cs_dense<T>(beta: T,
pub fn spmm_cs_dense<T>(
beta: T,
mut c: DMatrixSliceMut<T>,
alpha: T,
a: Op<&CsMatrix<T>>,
b: Op<DMatrixSlice<T>>)
where
T: Scalar + ClosedAdd + ClosedMul + Zero + One
b: Op<DMatrixSlice<T>>,
) where
T: Scalar + ClosedAdd + ClosedMul + Zero + One,
{
match a {
Op::NoOp(a) => {
@ -139,17 +145,17 @@ pub fn spmm_cs_dense<T>(beta: T,
for (c_ij, a_row_i) in c_col_j.iter_mut().zip(a.lane_iter()) {
let mut dot_ij = T::zero();
for (&k, a_ik) in a_row_i.minor_indices().iter().zip(a_row_i.values()) {
let b_contrib =
match b {
let b_contrib = match b {
Op::NoOp(ref b) => b.index((k, j)),
Op::Transpose(ref b) => b.index((j, k))
Op::Transpose(ref b) => b.index((j, k)),
};
dot_ij += a_ik.inlined_clone() * b_contrib.inlined_clone();
}
*c_ij = beta.inlined_clone() * c_ij.inlined_clone() + alpha.inlined_clone() * dot_ij;
*c_ij = beta.inlined_clone() * c_ij.inlined_clone()
+ alpha.inlined_clone() * dot_ij;
}
}
}
},
Op::Transpose(a) => {
// In this case, we have to pre-multiply C by beta
c *= beta;
@ -165,17 +171,16 @@ pub fn spmm_cs_dense<T>(beta: T,
for (c_ij, b_kj) in c_row_i.iter_mut().zip(b_row_k.iter()) {
*c_ij += gamma_ki.inlined_clone() * b_kj.inlined_clone();
}
},
}
Op::Transpose(ref b) => {
let b_col_k = b.column(k);
for (c_ij, b_jk) in c_row_i.iter_mut().zip(b_col_k.iter()) {
*c_ij += gamma_ki.inlined_clone() * b_jk.inlined_clone();
}
},
}
}
}
},
}
}
}
}

View File

@ -1,9 +1,9 @@
use crate::csc::CscMatrix;
use crate::ops::Op;
use crate::ops::serial::cs::{spmm_cs_prealloc, spmm_cs_dense, spadd_cs_prealloc};
use crate::ops::serial::cs::{spadd_cs_prealloc, spmm_cs_dense, spmm_cs_prealloc};
use crate::ops::serial::{OperationError, OperationErrorKind};
use nalgebra::{Scalar, ClosedAdd, ClosedMul, DMatrixSliceMut, DMatrixSlice, RealField};
use num_traits::{Zero, One};
use crate::ops::Op;
use nalgebra::{ClosedAdd, ClosedMul, DMatrixSlice, DMatrixSliceMut, RealField, Scalar};
use num_traits::{One, Zero};
use std::borrow::Cow;
@ -12,25 +12,27 @@ use std::borrow::Cow;
/// # Panics
///
/// Panics if the dimensions of the matrices involved are not compatible with the expression.
pub fn spmm_csc_dense<'a, T>(beta: T,
pub fn spmm_csc_dense<'a, T>(
beta: T,
c: impl Into<DMatrixSliceMut<'a, T>>,
alpha: T,
a: Op<&CscMatrix<T>>,
b: Op<impl Into<DMatrixSlice<'a, T>>>)
where
T: Scalar + ClosedAdd + ClosedMul + Zero + One
b: Op<impl Into<DMatrixSlice<'a, T>>>,
) where
T: Scalar + ClosedAdd + ClosedMul + Zero + One,
{
let b = b.convert();
spmm_csc_dense_(beta, c.into(), alpha, a, b)
}
fn spmm_csc_dense_<T>(beta: T,
fn spmm_csc_dense_<T>(
beta: T,
c: DMatrixSliceMut<T>,
alpha: T,
a: Op<&CscMatrix<T>>,
b: Op<DMatrixSlice<T>>)
where
T: Scalar + ClosedAdd + ClosedMul + Zero + One
b: Op<DMatrixSlice<T>>,
) where
T: Scalar + ClosedAdd + ClosedMul + Zero + One,
{
assert_compatible_spmm_dims!(c, a, b);
// Need to interpret matrix as transposed since the spmm_cs_dense function assumes CSR layout
@ -46,19 +48,19 @@ fn spmm_csc_dense_<T>(beta: T,
/// # Panics
///
/// Panics if the dimensions of the matrices involved are not compatible with the expression.
pub fn spadd_csc_prealloc<T>(beta: T,
pub fn spadd_csc_prealloc<T>(
beta: T,
c: &mut CscMatrix<T>,
alpha: T,
a: Op<&CscMatrix<T>>)
-> Result<(), OperationError>
a: Op<&CscMatrix<T>>,
) -> Result<(), OperationError>
where
T: Scalar + ClosedAdd + ClosedMul + Zero + One
T: Scalar + ClosedAdd + ClosedMul + Zero + One,
{
assert_compatible_spadd_dims!(c, a);
spadd_cs_prealloc(beta, &mut c.cs, alpha, a.map_same_op(|a| &a.cs))
}
/// Sparse-sparse matrix multiplication, `C <- beta * C + alpha * op(A) * op(B)`.
///
/// # Errors
@ -74,10 +76,10 @@ pub fn spmm_csc_prealloc<T>(
c: &mut CscMatrix<T>,
alpha: T,
a: Op<&CscMatrix<T>>,
b: Op<&CscMatrix<T>>)
-> Result<(), OperationError>
b: Op<&CscMatrix<T>>,
) -> Result<(), OperationError>
where
T: Scalar + ClosedAdd + ClosedMul + Zero + One
T: Scalar + ClosedAdd + ClosedMul + Zero + One,
{
assert_compatible_spmm_dims!(c, a, b);
@ -87,7 +89,7 @@ pub fn spmm_csc_prealloc<T>(
(NoOp(ref a), NoOp(ref b)) => {
// Note: We have to reverse the order for CSC matrices
spmm_cs_prealloc(beta, &mut c.cs, alpha, &b.cs, &a.cs)
},
}
_ => {
// Currently we handle transposition by explicitly precomputing transposed matrices
// and calling the operation again without transposition
@ -99,7 +101,9 @@ pub fn spmm_csc_prealloc<T>(
(NoOp(_), NoOp(_)) => unreachable!(),
(Transpose(ref a), NoOp(_)) => (Owned(a.transpose()), Borrowed(b_ref)),
(NoOp(_), Transpose(ref b)) => (Borrowed(a_ref), Owned(b.transpose())),
(Transpose(ref a), Transpose(ref b)) => (Owned(a.transpose()), Owned(b.transpose()))
(Transpose(ref a), Transpose(ref b)) => {
(Owned(a.transpose()), Owned(b.transpose()))
}
}
};
@ -121,13 +125,20 @@ pub fn spmm_csc_prealloc<T>(
/// Panics if `L` is not square, or if `L` and `B` are not dimensionally compatible.
pub fn spsolve_csc_lower_triangular<'a, T: RealField>(
l: Op<&CscMatrix<T>>,
b: impl Into<DMatrixSliceMut<'a, T>>)
-> Result<(), OperationError>
{
b: impl Into<DMatrixSliceMut<'a, T>>,
) -> Result<(), OperationError> {
let b = b.into();
let l_matrix = l.into_inner();
assert_eq!(l_matrix.nrows(), l_matrix.ncols(), "Matrix must be square for triangular solve.");
assert_eq!(l_matrix.nrows(), b.nrows(), "Dimension mismatch in sparse lower triangular solver.");
assert_eq!(
l_matrix.nrows(),
l_matrix.ncols(),
"Matrix must be square for triangular solve."
);
assert_eq!(
l_matrix.nrows(),
b.nrows(),
"Dimension mismatch in sparse lower triangular solver."
);
match l {
Op::NoOp(a) => spsolve_csc_lower_triangular_no_transpose(a, b),
Op::Transpose(a) => spsolve_csc_lower_triangular_transpose(a, b),
@ -136,9 +147,8 @@ pub fn spsolve_csc_lower_triangular<'a, T: RealField>(
fn spsolve_csc_lower_triangular_no_transpose<T: RealField>(
l: &CscMatrix<T>,
b: DMatrixSliceMut<T>)
-> Result<(), OperationError>
{
b: DMatrixSliceMut<T>,
) -> Result<(), OperationError> {
let mut x = b;
// Solve column-by-column
@ -187,14 +197,16 @@ fn spsolve_csc_lower_triangular_no_transpose<T: RealField>(
fn spsolve_encountered_zero_diagonal() -> Result<(), OperationError> {
let message = "Matrix contains at least one diagonal entry that is zero.";
Err(OperationError::from_kind_and_message(OperationErrorKind::Singular, String::from(message)))
Err(OperationError::from_kind_and_message(
OperationErrorKind::Singular,
String::from(message),
))
}
fn spsolve_csc_lower_triangular_transpose<T: RealField>(
l: &CscMatrix<T>,
b: DMatrixSliceMut<T>)
-> Result<(), OperationError>
{
b: DMatrixSliceMut<T>,
) -> Result<(), OperationError> {
let mut x = b;
// Solve column-by-column

View File

@ -1,31 +1,33 @@
use crate::csr::CsrMatrix;
use crate::ops::{Op};
use crate::ops::serial::{OperationError};
use nalgebra::{Scalar, DMatrixSlice, ClosedAdd, ClosedMul, DMatrixSliceMut};
use num_traits::{Zero, One};
use crate::ops::serial::cs::{spadd_cs_prealloc, spmm_cs_dense, spmm_cs_prealloc};
use crate::ops::serial::OperationError;
use crate::ops::Op;
use nalgebra::{ClosedAdd, ClosedMul, DMatrixSlice, DMatrixSliceMut, Scalar};
use num_traits::{One, Zero};
use std::borrow::Cow;
use crate::ops::serial::cs::{spmm_cs_prealloc, spmm_cs_dense, spadd_cs_prealloc};
/// Sparse-dense matrix-matrix multiplication `C <- beta * C + alpha * op(A) * op(B)`.
pub fn spmm_csr_dense<'a, T>(beta: T,
pub fn spmm_csr_dense<'a, T>(
beta: T,
c: impl Into<DMatrixSliceMut<'a, T>>,
alpha: T,
a: Op<&CsrMatrix<T>>,
b: Op<impl Into<DMatrixSlice<'a, T>>>)
where
T: Scalar + ClosedAdd + ClosedMul + Zero + One
b: Op<impl Into<DMatrixSlice<'a, T>>>,
) where
T: Scalar + ClosedAdd + ClosedMul + Zero + One,
{
let b = b.convert();
spmm_csr_dense_(beta, c.into(), alpha, a, b)
}
fn spmm_csr_dense_<T>(beta: T,
fn spmm_csr_dense_<T>(
beta: T,
c: DMatrixSliceMut<T>,
alpha: T,
a: Op<&CsrMatrix<T>>,
b: Op<DMatrixSlice<T>>)
where
T: Scalar + ClosedAdd + ClosedMul + Zero + One
b: Op<DMatrixSlice<T>>,
) where
T: Scalar + ClosedAdd + ClosedMul + Zero + One,
{
assert_compatible_spmm_dims!(c, a, b);
spmm_cs_dense(beta, c, alpha, a.map_same_op(|a| &a.cs), b)
@ -41,13 +43,14 @@ where
/// # Panics
///
/// Panics if the dimensions of the matrices involved are not compatible with the expression.
pub fn spadd_csr_prealloc<T>(beta: T,
pub fn spadd_csr_prealloc<T>(
beta: T,
c: &mut CsrMatrix<T>,
alpha: T,
a: Op<&CsrMatrix<T>>)
-> Result<(), OperationError>
a: Op<&CsrMatrix<T>>,
) -> Result<(), OperationError>
where
T: Scalar + ClosedAdd + ClosedMul + Zero + One
T: Scalar + ClosedAdd + ClosedMul + Zero + One,
{
assert_compatible_spadd_dims!(c, a);
spadd_cs_prealloc(beta, &mut c.cs, alpha, a.map_same_op(|a| &a.cs))
@ -67,19 +70,17 @@ pub fn spmm_csr_prealloc<T>(
c: &mut CsrMatrix<T>,
alpha: T,
a: Op<&CsrMatrix<T>>,
b: Op<&CsrMatrix<T>>)
-> Result<(), OperationError>
b: Op<&CsrMatrix<T>>,
) -> Result<(), OperationError>
where
T: Scalar + ClosedAdd + ClosedMul + Zero + One
T: Scalar + ClosedAdd + ClosedMul + Zero + One,
{
assert_compatible_spmm_dims!(c, a, b);
use Op::{NoOp, Transpose};
match (&a, &b) {
(NoOp(ref a), NoOp(ref b)) => {
spmm_cs_prealloc(beta, &mut c.cs, alpha, &a.cs, &b.cs)
},
(NoOp(ref a), NoOp(ref b)) => spmm_cs_prealloc(beta, &mut c.cs, alpha, &a.cs, &b.cs),
_ => {
// Currently we handle transposition by explicitly precomputing transposed matrices
// and calling the operation again without transposition
@ -93,7 +94,9 @@ where
(NoOp(_), NoOp(_)) => unreachable!(),
(Transpose(ref a), NoOp(_)) => (Owned(a.transpose()), Borrowed(b_ref)),
(NoOp(_), Transpose(ref b)) => (Borrowed(a_ref), Owned(b.transpose())),
(Transpose(ref a), Transpose(ref b)) => (Owned(a.transpose()), Owned(b.transpose()))
(Transpose(ref a), Transpose(ref b)) => {
(Owned(a.transpose()), Owned(b.transpose()))
}
}
};
@ -101,4 +104,3 @@ where
}
}
}

View File

@ -10,33 +10,31 @@
#[macro_use]
macro_rules! assert_compatible_spmm_dims {
($c:expr, $a:expr, $b:expr) => {
{
($c:expr, $a:expr, $b:expr) => {{
use crate::ops::Op::{NoOp, Transpose};
match (&$a, &$b) {
(NoOp(ref a), NoOp(ref b)) => {
assert_eq!($c.nrows(), a.nrows(), "C.nrows() != A.nrows()");
assert_eq!($c.ncols(), b.ncols(), "C.ncols() != B.ncols()");
assert_eq!(a.ncols(), b.nrows(), "A.ncols() != B.nrows()");
},
}
(Transpose(ref a), NoOp(ref b)) => {
assert_eq!($c.nrows(), a.ncols(), "C.nrows() != A.ncols()");
assert_eq!($c.ncols(), b.ncols(), "C.ncols() != B.ncols()");
assert_eq!(a.nrows(), b.nrows(), "A.nrows() != B.nrows()");
},
}
(NoOp(ref a), Transpose(ref b)) => {
assert_eq!($c.nrows(), a.nrows(), "C.nrows() != A.nrows()");
assert_eq!($c.ncols(), b.nrows(), "C.ncols() != B.nrows()");
assert_eq!(a.ncols(), b.ncols(), "A.ncols() != B.ncols()");
},
}
(Transpose(ref a), Transpose(ref b)) => {
assert_eq!($c.nrows(), a.ncols(), "C.nrows() != A.ncols()");
assert_eq!($c.ncols(), b.nrows(), "C.ncols() != B.nrows()");
assert_eq!(a.nrows(), b.ncols(), "A.nrows() != B.ncols()");
}
}
}
}
}};
}
#[macro_use]
@ -47,32 +45,31 @@ macro_rules! assert_compatible_spadd_dims {
Op::NoOp(a) => {
assert_eq!($c.nrows(), a.nrows(), "C.nrows() != A.nrows()");
assert_eq!($c.ncols(), a.ncols(), "C.ncols() != A.ncols()");
},
}
Op::Transpose(a) => {
assert_eq!($c.nrows(), a.ncols(), "C.nrows() != A.ncols()");
assert_eq!($c.ncols(), a.nrows(), "C.ncols() != A.nrows()");
}
}
}
};
}
mod cs;
mod csc;
mod csr;
mod pattern;
mod cs;
pub use csc::*;
pub use csr::*;
pub use pattern::*;
use std::fmt::Formatter;
use std::fmt;
use std::fmt::Formatter;
/// A description of the error that occurred during an arithmetic operation.
#[derive(Clone, Debug)]
pub struct OperationError {
error_kind: OperationErrorKind,
message: String
message: String,
}
/// The different kinds of operation errors that may occur.
@ -92,7 +89,10 @@ pub enum OperationErrorKind {
impl OperationError {
fn from_kind_and_message(error_type: OperationErrorKind, message: String) -> Self {
Self { error_kind: error_type, message }
Self {
error_kind: error_type,
message,
}
}
/// The operation error kind.
@ -110,8 +110,12 @@ impl fmt::Display for OperationError {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
write!(f, "Sparse matrix operation error: ")?;
match self.kind() {
OperationErrorKind::InvalidPattern => { write!(f, "InvalidPattern")?; }
OperationErrorKind::Singular => { write!(f, "Singular")?; }
OperationErrorKind::InvalidPattern => {
write!(f, "InvalidPattern")?;
}
OperationErrorKind::Singular => {
write!(f, "Singular")?;
}
}
write!(f, " Message: {}", self.message)
}

View File

@ -12,11 +12,17 @@ use std::iter;
/// # Panics
///
/// Panics if the patterns do not have the same major and minor dimensions.
pub fn spadd_pattern(a: &SparsityPattern,
b: &SparsityPattern) -> SparsityPattern
{
assert_eq!(a.major_dim(), b.major_dim(), "Patterns must have identical major dimensions.");
assert_eq!(a.minor_dim(), b.minor_dim(), "Patterns must have identical minor dimensions.");
pub fn spadd_pattern(a: &SparsityPattern, b: &SparsityPattern) -> SparsityPattern {
assert_eq!(
a.major_dim(),
b.major_dim(),
"Patterns must have identical major dimensions."
);
assert_eq!(
a.minor_dim(),
b.minor_dim(),
"Patterns must have identical minor dimensions."
);
let mut offsets = Vec::new();
let mut indices = Vec::new();
@ -33,8 +39,7 @@ pub fn spadd_pattern(a: &SparsityPattern,
}
// TODO: Consider circumventing format checks? (requires unsafe, should benchmark first)
SparsityPattern::try_from_offsets_and_indices(
a.major_dim(), a.minor_dim(), offsets, indices)
SparsityPattern::try_from_offsets_and_indices(a.major_dim(), a.minor_dim(), offsets, indices)
.expect("Internal error: Pattern must be valid by definition")
}
@ -66,7 +71,11 @@ pub fn spmm_csc_pattern(a: &SparsityPattern, b: &SparsityPattern) -> SparsityPat
/// Panics if the patterns, when interpreted as CSR patterns, are not compatible for
/// matrix multiplication.
pub fn spmm_csr_pattern(a: &SparsityPattern, b: &SparsityPattern) -> SparsityPattern {
assert_eq!(a.minor_dim(), b.major_dim(), "a and b must have compatible dimensions");
assert_eq!(
a.minor_dim(),
b.major_dim(),
"a and b must have compatible dimensions"
);
let mut offsets = Vec::new();
let mut indices = Vec::new();
@ -110,8 +119,10 @@ pub fn spmm_csr_pattern(a: &SparsityPattern, b: &SparsityPattern) -> SparsityPat
/// Iterate over the union of the two sets represented by sorted slices
/// (with unique elements)
fn iterate_union<'a>(mut sorted_a: &'a [usize],
mut sorted_b: &'a [usize]) -> impl Iterator<Item=usize> + 'a {
fn iterate_union<'a>(
mut sorted_a: &'a [usize],
mut sorted_b: &'a [usize],
) -> impl Iterator<Item = usize> + 'a {
iter::from_fn(move || {
if let (Some(a_item), Some(b_item)) = (sorted_a.first(), sorted_b.first()) {
let item = if a_item < b_item {

View File

@ -1,8 +1,8 @@
//! Sparsity patterns for CSR and CSC matrices.
use crate::SparseFormatError;
use std::fmt;
use std::error::Error;
use crate::cs::transpose_cs;
use crate::SparseFormatError;
use std::error::Error;
use std::fmt;
/// A representation of the sparsity pattern of a CSR or CSC matrix.
///
@ -236,12 +236,15 @@ impl SparsityPattern {
self.minor_dim(),
self.major_offsets(),
self.minor_indices(),
&values);
&values,
);
// TODO: Skip checks
Self::try_from_offsets_and_indices(self.minor_dim(),
Self::try_from_offsets_and_indices(
self.minor_dim(),
self.major_dim(),
new_offsets,
new_indices)
new_indices,
)
.expect("Internal error: Transpose should never fail.")
}
}
@ -275,22 +278,25 @@ pub enum SparsityPatternFormatError {
impl From<SparsityPatternFormatError> for SparseFormatError {
fn from(err: SparsityPatternFormatError) -> Self {
use SparsityPatternFormatError::*;
use SparsityPatternFormatError::DuplicateEntry as PatternDuplicateEntry;
use crate::SparseFormatErrorKind;
use crate::SparseFormatErrorKind::*;
use SparsityPatternFormatError::DuplicateEntry as PatternDuplicateEntry;
use SparsityPatternFormatError::*;
match err {
InvalidOffsetArrayLength
| InvalidOffsetFirstLast
| NonmonotonicOffsets
| NonmonotonicMinorIndices
=> SparseFormatError::from_kind_and_error(InvalidStructure, Box::from(err)),
MinorIndexOutOfBounds
=> SparseFormatError::from_kind_and_error(IndexOutOfBounds,
Box::from(err)),
PatternDuplicateEntry
=> SparseFormatError::from_kind_and_error(SparseFormatErrorKind::DuplicateEntry,
Box::from(err)),
| NonmonotonicMinorIndices => {
SparseFormatError::from_kind_and_error(InvalidStructure, Box::from(err))
}
MinorIndexOutOfBounds => {
SparseFormatError::from_kind_and_error(IndexOutOfBounds, Box::from(err))
}
PatternDuplicateEntry => SparseFormatError::from_kind_and_error(
#[allow(unused_qualifications)]
SparseFormatErrorKind::DuplicateEntry,
Box::from(err),
),
}
}
}
@ -300,22 +306,25 @@ impl fmt::Display for SparsityPatternFormatError {
match self {
SparsityPatternFormatError::InvalidOffsetArrayLength => {
write!(f, "Length of offset array is not equal to (major_dim + 1).")
},
}
SparsityPatternFormatError::InvalidOffsetFirstLast => {
write!(f, "First or last offset is incompatible with format.")
},
}
SparsityPatternFormatError::NonmonotonicOffsets => {
write!(f, "Offsets are not monotonically increasing.")
},
}
SparsityPatternFormatError::MinorIndexOutOfBounds => {
write!(f, "A minor index is out of bounds.")
},
}
SparsityPatternFormatError::DuplicateEntry => {
write!(f, "Input data contains duplicate entries.")
},
}
SparsityPatternFormatError::NonmonotonicMinorIndices => {
write!(f, "Minor indices are not monotonically increasing within each lane.")
},
write!(
f,
"Minor indices are not monotonically increasing within each lane."
)
}
}
}
}
@ -340,7 +349,7 @@ impl<'a> SparsityPatternIter<'a> {
major_offsets: pattern.major_offsets(),
minor_indices: pattern.minor_indices(),
current_lane_idx: 0,
remaining_minors_in_lane: minors_in_first_lane
remaining_minors_in_lane: minors_in_first_lane,
}
}
}
@ -375,11 +384,10 @@ impl<'a> Iterator for SparsityPatternIter<'a> {
let upper = self.major_offsets[self.current_lane_idx + 1];
if upper > lower {
self.remaining_minors_in_lane = &self.minor_indices[(lower + 1)..upper];
return Some((self.current_lane_idx, self.minor_indices[lower]))
return Some((self.current_lane_idx, self.minor_indices[lower]));
}
}
}
}
}
}

View File

@ -11,20 +11,22 @@
mod proptest_patched;
use crate::coo::CooMatrix;
use proptest::prelude::*;
use proptest::collection::{vec, hash_map, btree_set};
use nalgebra::{Scalar, Dim};
use std::cmp::min;
use std::iter::{repeat};
use proptest::sample::{Index};
use crate::csc::CscMatrix;
use crate::csr::CsrMatrix;
use crate::pattern::SparsityPattern;
use crate::csc::CscMatrix;
use nalgebra::proptest::DimRange;
use nalgebra::{Dim, Scalar};
use proptest::collection::{btree_set, hash_map, vec};
use proptest::prelude::*;
use proptest::sample::Index;
use std::cmp::min;
use std::iter::repeat;
fn dense_row_major_coord_strategy(nrows: usize, ncols: usize, nnz: usize)
-> impl Strategy<Value=Vec<(usize, usize)>>
{
fn dense_row_major_coord_strategy(
nrows: usize,
ncols: usize,
nnz: usize,
) -> impl Strategy<Value = Vec<(usize, usize)>> {
assert!(nnz <= nrows * ncols);
let mut booleans = vec![true; nnz];
booleans.append(&mut vec![false; (nrows * ncols) - nnz]);
@ -38,8 +40,7 @@ fn dense_row_major_coord_strategy(nrows: usize, ncols: usize, nnz: usize)
// // Need to shuffle to make sure they are randomly distributed
// .prop_shuffle()
proptest_patched::Shuffle(Just(booleans))
.prop_map(move |booleans| {
proptest_patched::Shuffle(Just(booleans)).prop_map(move |booleans| {
booleans
.into_iter()
.enumerate()
@ -60,11 +61,12 @@ fn dense_row_major_coord_strategy(nrows: usize, ncols: usize, nnz: usize)
/// A strategy for generating `nnz` triplets.
///
/// This strategy should generally only be used when `nnz` is close to `nrows * ncols`.
fn dense_triplet_strategy<T>(value_strategy: T,
fn dense_triplet_strategy<T>(
value_strategy: T,
nrows: usize,
ncols: usize,
nnz: usize)
-> impl Strategy<Value=Vec<(usize, usize, T::Value)>>
nnz: usize,
) -> impl Strategy<Value = Vec<(usize, usize, T::Value)>>
where
T: Strategy + Clone + 'static,
T::Value: Scalar,
@ -100,13 +102,12 @@ where
})
// Assign values to each coordinate pair in order to generate a list of triplets
.prop_flat_map(move |coords| {
vec![value_strategy.clone(); coords.len()]
.prop_map(move |values| {
coords.clone().into_iter()
vec![value_strategy.clone(); coords.len()].prop_map(move |values| {
coords
.clone()
.into_iter()
.zip(values)
.map(|((i, j), v)| {
(i, j, v)
})
.map(|((i, j), v)| (i, j, v))
.collect::<Vec<_>>()
})
})
@ -116,11 +117,12 @@ where
///
/// This strategy should generally only be used when `nnz << nrows * ncols`. If `nnz` is too
/// close to `nrows * ncols` it may fail due to excessive rejected samples.
fn sparse_triplet_strategy<T>(value_strategy: T,
fn sparse_triplet_strategy<T>(
value_strategy: T,
nrows: usize,
ncols: usize,
nnz: usize)
-> impl Strategy<Value=Vec<(usize, usize, T::Value)>>
nnz: usize,
) -> impl Strategy<Value = Vec<(usize, usize, T::Value)>>
where
T: Strategy + Clone + 'static,
T::Value: Scalar,
@ -131,10 +133,7 @@ fn sparse_triplet_strategy<T>(value_strategy: T,
let coord_strategy = (row_index_strategy, col_index_strategy);
hash_map(coord_strategy, value_strategy.clone(), nnz)
.prop_map(|hash_map| {
let triplets: Vec<_> = hash_map
.into_iter()
.map(|((i, j), v)| (i, j, v))
.collect();
let triplets: Vec<_> = hash_map.into_iter().map(|((i, j), v)| (i, j, v)).collect();
triplets
})
// Although order in the hash map is unspecified, it's not necessarily *random*
@ -153,18 +152,23 @@ pub fn coo_no_duplicates<T>(
value_strategy: T,
rows: impl Into<DimRange>,
cols: impl Into<DimRange>,
max_nonzeros: usize) -> impl Strategy<Value=CooMatrix<T::Value>>
max_nonzeros: usize,
) -> impl Strategy<Value = CooMatrix<T::Value>>
where
T: Strategy + Clone + 'static,
T::Value: Scalar,
{
(rows.into().to_range_inclusive(), cols.into().to_range_inclusive())
(
rows.into().to_range_inclusive(),
cols.into().to_range_inclusive(),
)
.prop_flat_map(move |(nrows, ncols)| {
let max_nonzeros = min(max_nonzeros, nrows * ncols);
let size_range = 0..=max_nonzeros;
let value_strategy = value_strategy.clone();
size_range.prop_flat_map(move |nnz| {
size_range
.prop_flat_map(move |nnz| {
let value_strategy = value_strategy.clone();
if nnz as f64 > 0.10 * (nrows as f64) * (ncols as f64) {
// If the number of nnz is sufficiently dense, then use the dense
@ -202,8 +206,8 @@ pub fn coo_with_duplicates<T>(
rows: impl Into<DimRange>,
cols: impl Into<DimRange>,
max_nonzeros: usize,
max_duplicates: usize)
-> impl Strategy<Value=CooMatrix<T::Value>>
max_duplicates: usize,
) -> impl Strategy<Value = CooMatrix<T::Value>>
where
T: Strategy + Clone + 'static,
T::Value: Scalar,
@ -212,7 +216,8 @@ where
let duplicate_strategy = vec((any::<Index>(), value_strategy.clone()), 0..=max_duplicates);
(coo_strategy, duplicate_strategy)
.prop_flat_map(|(coo, duplicates)| {
let mut triplets: Vec<(usize, usize, T::Value)> = coo.triplet_iter()
let mut triplets: Vec<(usize, usize, T::Value)> = coo
.triplet_iter()
.map(|(i, j, v)| (i, j, v.clone()))
.collect();
if !triplets.is_empty() {
@ -238,7 +243,11 @@ where
})
}
fn sparsity_pattern_from_row_major_coords<I>(nmajor: usize, nminor: usize, coords: I) -> SparsityPattern
fn sparsity_pattern_from_row_major_coords<I>(
nmajor: usize,
nminor: usize,
coords: I,
) -> SparsityPattern
where
I: Iterator<Item = (usize, usize)> + ExactSizeIterator,
{
@ -248,7 +257,10 @@ where
offsets.push(0);
for (idx, (i, j)) in coords.enumerate() {
assert!(i >= current_major);
assert!(i < nmajor && j < nminor, "Generated coords are out of bounds");
assert!(
i < nmajor && j < nminor,
"Generated coords are out of bounds"
);
while current_major < i {
offsets.push(idx);
current_major += 1;
@ -264,10 +276,7 @@ where
assert_eq!(offsets.first().unwrap(), &0);
assert_eq!(offsets.len(), nmajor + 1);
SparsityPattern::try_from_offsets_and_indices(nmajor,
nminor,
offsets,
minors)
SparsityPattern::try_from_offsets_and_indices(nmajor, nminor, offsets, minors)
.expect("Internal error: Generated sparsity pattern is invalid")
}
@ -275,14 +284,17 @@ where
pub fn sparsity_pattern(
major_lanes: impl Into<DimRange>,
minor_lanes: impl Into<DimRange>,
max_nonzeros: usize)
-> impl Strategy<Value=SparsityPattern>
{
(major_lanes.into().to_range_inclusive(), minor_lanes.into().to_range_inclusive())
max_nonzeros: usize,
) -> impl Strategy<Value = SparsityPattern> {
(
major_lanes.into().to_range_inclusive(),
minor_lanes.into().to_range_inclusive(),
)
.prop_flat_map(move |(nmajor, nminor)| {
let max_nonzeros = min(nmajor * nminor, max_nonzeros);
(Just(nmajor), Just(nminor), 0..=max_nonzeros)
}).prop_flat_map(move |(nmajor, nminor, nnz)| {
})
.prop_flat_map(move |(nmajor, nminor, nnz)| {
if 10 * nnz < nmajor * nminor {
// If nnz is small compared to a dense matrix, then use a sparse sampling strategy
btree_set((0..nmajor, 0..nminor), nnz)
@ -297,24 +309,30 @@ pub fn sparsity_pattern(
.prop_map(move |coords| {
let coords = coords.into_iter();
sparsity_pattern_from_row_major_coords(nmajor, nminor, coords)
}).boxed()
})
.boxed()
}
})
}
/// A strategy for generating CSR matrices.
pub fn csr<T>(value_strategy: T,
pub fn csr<T>(
value_strategy: T,
rows: impl Into<DimRange>,
cols: impl Into<DimRange>,
max_nonzeros: usize)
-> impl Strategy<Value=CsrMatrix<T::Value>>
max_nonzeros: usize,
) -> impl Strategy<Value = CsrMatrix<T::Value>>
where
T: Strategy + Clone + 'static,
T::Value: Scalar,
{
let rows = rows.into();
let cols = cols.into();
sparsity_pattern(rows.lower_bound().value() ..= rows.upper_bound().value(), cols.lower_bound().value() ..= cols.upper_bound().value(), max_nonzeros)
sparsity_pattern(
rows.lower_bound().value()..=rows.upper_bound().value(),
cols.lower_bound().value()..=cols.upper_bound().value(),
max_nonzeros,
)
.prop_flat_map(move |pattern| {
let nnz = pattern.nnz();
let values = vec![value_strategy.clone(); nnz];
@ -327,18 +345,23 @@ where
}
/// A strategy for generating CSC matrices.
pub fn csc<T>(value_strategy: T,
pub fn csc<T>(
value_strategy: T,
rows: impl Into<DimRange>,
cols: impl Into<DimRange>,
max_nonzeros: usize)
-> impl Strategy<Value=CscMatrix<T::Value>>
max_nonzeros: usize,
) -> impl Strategy<Value = CscMatrix<T::Value>>
where
T: Strategy + Clone + 'static,
T::Value: Scalar,
{
let rows = rows.into();
let cols = cols.into();
sparsity_pattern(cols.lower_bound().value() ..= cols.upper_bound().value(), rows.lower_bound().value() ..= rows.upper_bound().value(), max_nonzeros)
sparsity_pattern(
cols.lower_bound().value()..=cols.upper_bound().value(),
rows.lower_bound().value()..=rows.upper_bound().value(),
max_nonzeros,
)
.prop_flat_map(move |pattern| {
let nnz = pattern.nnz();
let values = vec![value_strategy.clone(); nnz];

View File

@ -22,11 +22,11 @@
*/
use proptest::strategy::{Strategy, Shuffleable, NewTree, ValueTree};
use proptest::test_runner::{TestRunner, TestRng};
use std::cell::Cell;
use proptest::num;
use proptest::prelude::Rng;
use proptest::strategy::{NewTree, Shuffleable, Strategy, ValueTree};
use proptest::test_runner::{TestRng, TestRunner};
use std::cell::Cell;
#[derive(Clone, Debug)]
#[must_use = "strategies do nothing unless used"]

View File

@ -1,15 +1,15 @@
use proptest::strategy::Strategy;
use nalgebra_sparse::csr::CsrMatrix;
use nalgebra_sparse::proptest::{csr, csc};
use nalgebra_sparse::csc::CscMatrix;
use std::ops::RangeInclusive;
use std::convert::{TryFrom};
use nalgebra_sparse::csr::CsrMatrix;
use nalgebra_sparse::proptest::{csc, csr};
use proptest::strategy::Strategy;
use std::convert::TryFrom;
use std::fmt::Debug;
use std::ops::RangeInclusive;
#[macro_export]
macro_rules! assert_panics {
($e:expr) => {{
use std::panic::{catch_unwind};
use std::panic::catch_unwind;
use std::stringify;
let expr_string = stringify!($e);
@ -22,7 +22,10 @@ macro_rules! assert_panics {
let result = catch_unwind(|| $e);
if result.is_ok() {
panic!("assert_panics!({}) failed: the expression did not panic.", expr_string);
panic!(
"assert_panics!({}) failed: the expression did not panic.",
expr_string
);
}
}};
}
@ -34,14 +37,20 @@ pub const PROPTEST_I32_VALUE_STRATEGY: RangeInclusive<i32> = -5 ..= 5;
pub fn value_strategy<T>() -> RangeInclusive<T>
where
T: TryFrom<i32>,
T::Error: Debug
T::Error: Debug,
{
let (start, end) = (PROPTEST_I32_VALUE_STRATEGY.start(), PROPTEST_I32_VALUE_STRATEGY.end());
let (start, end) = (
PROPTEST_I32_VALUE_STRATEGY.start(),
PROPTEST_I32_VALUE_STRATEGY.end(),
);
T::try_from(*start).unwrap()..=T::try_from(*end).unwrap()
}
pub fn non_zero_i32_value_strategy() -> impl Strategy<Value = i32> {
let (start, end) = (PROPTEST_I32_VALUE_STRATEGY.start(), PROPTEST_I32_VALUE_STRATEGY.end());
let (start, end) = (
PROPTEST_I32_VALUE_STRATEGY.start(),
PROPTEST_I32_VALUE_STRATEGY.end(),
);
assert!(start < &0);
assert!(end > &0);
// Note: we don't use RangeInclusive for the second range, because then we'd have different
@ -50,9 +59,19 @@ pub fn non_zero_i32_value_strategy() -> impl Strategy<Value=i32> {
}
pub fn csr_strategy() -> impl Strategy<Value = CsrMatrix<i32>> {
csr(PROPTEST_I32_VALUE_STRATEGY, PROPTEST_MATRIX_DIM, PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ)
csr(
PROPTEST_I32_VALUE_STRATEGY,
PROPTEST_MATRIX_DIM,
PROPTEST_MATRIX_DIM,
PROPTEST_MAX_NNZ,
)
}
pub fn csc_strategy() -> impl Strategy<Value = CscMatrix<i32>> {
csc(PROPTEST_I32_VALUE_STRATEGY, PROPTEST_MATRIX_DIM, PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ)
csc(
PROPTEST_I32_VALUE_STRATEGY,
PROPTEST_MATRIX_DIM,
PROPTEST_MATRIX_DIM,
PROPTEST_MAX_NNZ,
)
}

View File

@ -1,17 +1,16 @@
use nalgebra_sparse::coo::CooMatrix;
use nalgebra_sparse::convert::serial::{convert_coo_dense, convert_coo_csr,
convert_dense_coo, convert_csr_dense,
convert_csr_coo, convert_dense_csr,
convert_csc_coo, convert_coo_csc,
convert_csc_dense, convert_dense_csc,
convert_csr_csc, convert_csc_csr};
use nalgebra_sparse::proptest::{coo_with_duplicates, coo_no_duplicates, csr, csc};
use nalgebra::proptest::matrix;
use proptest::prelude::*;
use nalgebra::DMatrix;
use nalgebra_sparse::csr::CsrMatrix;
use nalgebra_sparse::csc::CscMatrix;
use crate::common::csc_strategy;
use nalgebra::proptest::matrix;
use nalgebra::DMatrix;
use nalgebra_sparse::convert::serial::{
convert_coo_csc, convert_coo_csr, convert_coo_dense, convert_csc_coo, convert_csc_csr,
convert_csc_dense, convert_csr_coo, convert_csr_csc, convert_csr_dense, convert_dense_coo,
convert_dense_csc, convert_dense_csr,
};
use nalgebra_sparse::coo::CooMatrix;
use nalgebra_sparse::csc::CscMatrix;
use nalgebra_sparse::csr::CsrMatrix;
use nalgebra_sparse::proptest::{coo_no_duplicates, coo_with_duplicates, csc, csr};
use proptest::prelude::*;
#[test]
fn test_convert_dense_coo() {
@ -41,15 +40,16 @@ fn test_convert_dense_coo() {
// Here we implicitly test that the coo matrix is indeed constructed from column-major
// iteration of the dense matrix.
let dense = DMatrix::from_row_slice(2, 3, entries);
let coo_no_dup = CooMatrix::try_from_triplets(2, 3,
vec![0, 1, 0],
vec![0, 1, 2],
vec![1, 5, 3])
let coo_no_dup =
CooMatrix::try_from_triplets(2, 3, vec![0, 1, 0], vec![0, 1, 2], vec![1, 5, 3])
.unwrap();
let coo_dup = CooMatrix::try_from_triplets(2, 3,
let coo_dup = CooMatrix::try_from_triplets(
2,
3,
vec![0, 1, 0, 1],
vec![0, 1, 2, 1],
vec![1, -2, 3, 7])
vec![1, -2, 3, 7],
)
.unwrap();
assert_eq!(CooMatrix::from(&dense), coo_no_dup);
@ -76,8 +76,9 @@ fn test_convert_coo_csr() {
4,
vec![0, 1, 2, 5],
vec![1, 3, 0, 2, 3],
vec![2, 4, 1, 1, 2]
).unwrap();
vec![2, 4, 1, 1, 2],
)
.unwrap();
assert_eq!(convert_coo_csr(&coo), expected_csr);
}
@ -101,8 +102,9 @@ fn test_convert_coo_csr() {
4,
vec![0, 1, 2, 5],
vec![1, 3, 0, 2, 3],
vec![5, 4, 1, 1, 4]
).unwrap();
vec![5, 4, 1, 1, 4],
)
.unwrap();
assert_eq!(convert_coo_csr(&coo), expected_csr);
}
@ -115,16 +117,18 @@ fn test_convert_csr_coo() {
4,
vec![0, 1, 2, 5],
vec![1, 3, 0, 2, 3],
vec![5, 4, 1, 1, 4]
).unwrap();
vec![5, 4, 1, 1, 4],
)
.unwrap();
let expected_coo = CooMatrix::try_from_triplets(
3,
4,
vec![0, 1, 2, 2, 2],
vec![1, 3, 0, 2, 3],
vec![5, 4, 1, 1, 4]
).unwrap();
vec![5, 4, 1, 1, 4],
)
.unwrap();
assert_eq!(convert_csr_coo(&csr), expected_coo);
}
@ -148,8 +152,9 @@ fn test_convert_coo_csc() {
4,
vec![0, 1, 2, 3, 5],
vec![2, 0, 2, 1, 2],
vec![1, 2, 1, 4, 2]
).unwrap();
vec![1, 2, 1, 4, 2],
)
.unwrap();
assert_eq!(convert_coo_csc(&coo), expected_csc);
}
@ -173,8 +178,9 @@ fn test_convert_coo_csc() {
4,
vec![0, 1, 2, 3, 5],
vec![2, 0, 2, 1, 2],
vec![1, 5, 1, 4, 4]
).unwrap();
vec![1, 5, 1, 4, 4],
)
.unwrap();
assert_eq!(convert_coo_csc(&coo), expected_csc);
}
@ -187,16 +193,18 @@ fn test_convert_csc_coo() {
4,
vec![0, 1, 2, 3, 5],
vec![2, 0, 2, 1, 2],
vec![1, 2, 1, 4, 2]
).unwrap();
vec![1, 2, 1, 4, 2],
)
.unwrap();
let expected_coo = CooMatrix::try_from_triplets(
3,
4,
vec![2, 0, 2, 1, 2],
vec![0, 1, 2, 3, 3],
vec![1, 2, 1, 4, 2]
).unwrap();
vec![1, 2, 1, 4, 2],
)
.unwrap();
assert_eq!(convert_csc_coo(&csc), expected_coo);
}
@ -209,7 +217,8 @@ fn test_convert_csr_csc_bidirectional() {
vec![0, 3, 4, 6],
vec![1, 2, 3, 0, 1, 3],
vec![5, 3, 2, 2, 1, 4],
).unwrap();
)
.unwrap();
let csc = CscMatrix::try_from_csc_data(
3,
@ -217,7 +226,8 @@ fn test_convert_csr_csc_bidirectional() {
vec![0, 1, 3, 4, 6],
vec![1, 0, 2, 0, 0, 2],
vec![2, 5, 1, 3, 2, 4],
).unwrap();
)
.unwrap();
assert_eq!(convert_csr_csc(&csr), csc);
assert_eq!(convert_csc_csr(&csc), csr);
@ -231,7 +241,8 @@ fn test_convert_csr_dense_bidirectional() {
vec![0, 3, 4, 6],
vec![1, 2, 3, 0, 1, 3],
vec![5, 3, 2, 2, 1, 4],
).unwrap();
)
.unwrap();
#[rustfmt::skip]
let dense = DMatrix::from_row_slice(3, 4, &[
@ -252,7 +263,8 @@ fn test_convert_csc_dense_bidirectional() {
vec![0, 1, 3, 4, 6],
vec![1, 0, 2, 0, 0, 2],
vec![2, 5, 1, 3, 2, 4],
).unwrap();
)
.unwrap();
#[rustfmt::skip]
let dense = DMatrix::from_row_slice(3, 4, &[

View File

@ -1,7 +1,7 @@
use nalgebra_sparse::{SparseFormatErrorKind};
use nalgebra_sparse::coo::CooMatrix;
use nalgebra::DMatrix;
use crate::assert_panics;
use nalgebra::DMatrix;
use nalgebra_sparse::coo::CooMatrix;
use nalgebra_sparse::SparseFormatErrorKind;
#[test]
fn coo_construction_for_valid_data() {
@ -10,8 +10,8 @@ fn coo_construction_for_valid_data() {
{
// Zero matrix
let coo = CooMatrix::<i32>::try_from_triplets(3, 2, Vec::new(), Vec::new(), Vec::new())
.unwrap();
let coo =
CooMatrix::<i32>::try_from_triplets(3, 2, Vec::new(), Vec::new(), Vec::new()).unwrap();
assert_eq!(coo.nrows(), 3);
assert_eq!(coo.ncols(), 2);
assert!(coo.triplet_iter().next().is_none());
@ -27,8 +27,8 @@ fn coo_construction_for_valid_data() {
let i = vec![0, 1, 0, 0, 2];
let j = vec![0, 2, 1, 3, 3];
let v = vec![2, 3, 7, 3, 1];
let coo = CooMatrix::<i32>::try_from_triplets(3, 5, i.clone(), j.clone(), v.clone())
.unwrap();
let coo =
CooMatrix::<i32>::try_from_triplets(3, 5, i.clone(), j.clone(), v.clone()).unwrap();
assert_eq!(coo.nrows(), 3);
assert_eq!(coo.ncols(), 5);
@ -59,8 +59,8 @@ fn coo_construction_for_valid_data() {
let i = vec![0, 1, 0, 0, 0, 0, 2, 1];
let j = vec![0, 2, 0, 1, 0, 3, 3, 2];
let v = vec![2, 3, 4, 7, 1, 3, 1, 5];
let coo = CooMatrix::<i32>::try_from_triplets(3, 5, i.clone(), j.clone(), v.clone())
.unwrap();
let coo =
CooMatrix::<i32>::try_from_triplets(3, 5, i.clone(), j.clone(), v.clone()).unwrap();
assert_eq!(coo.nrows(), 3);
assert_eq!(coo.ncols(), 5);
@ -92,25 +92,37 @@ fn coo_try_from_triplets_reports_out_of_bounds_indices() {
{
// 0x0 matrix
let result = CooMatrix::<i32>::try_from_triplets(0, 0, vec![0], vec![0], vec![2]);
assert!(matches!(result.unwrap_err().kind(), SparseFormatErrorKind::IndexOutOfBounds));
assert!(matches!(
result.unwrap_err().kind(),
SparseFormatErrorKind::IndexOutOfBounds
));
}
{
// 1x1 matrix, row out of bounds
let result = CooMatrix::<i32>::try_from_triplets(1, 1, vec![1], vec![0], vec![2]);
assert!(matches!(result.unwrap_err().kind(), SparseFormatErrorKind::IndexOutOfBounds));
assert!(matches!(
result.unwrap_err().kind(),
SparseFormatErrorKind::IndexOutOfBounds
));
}
{
// 1x1 matrix, col out of bounds
let result = CooMatrix::<i32>::try_from_triplets(1, 1, vec![0], vec![1], vec![2]);
assert!(matches!(result.unwrap_err().kind(), SparseFormatErrorKind::IndexOutOfBounds));
assert!(matches!(
result.unwrap_err().kind(),
SparseFormatErrorKind::IndexOutOfBounds
));
}
{
// 1x1 matrix, row and col out of bounds
let result = CooMatrix::<i32>::try_from_triplets(1, 1, vec![1], vec![1], vec![2]);
assert!(matches!(result.unwrap_err().kind(), SparseFormatErrorKind::IndexOutOfBounds));
assert!(matches!(
result.unwrap_err().kind(),
SparseFormatErrorKind::IndexOutOfBounds
));
}
{
@ -119,7 +131,10 @@ fn coo_try_from_triplets_reports_out_of_bounds_indices() {
let j = vec![0, 2, 1, 3, 3];
let v = vec![2, 3, 7, 3, 1];
let result = CooMatrix::<i32>::try_from_triplets(3, 5, i, j, v);
assert!(matches!(result.unwrap_err().kind(), SparseFormatErrorKind::IndexOutOfBounds));
assert!(matches!(
result.unwrap_err().kind(),
SparseFormatErrorKind::IndexOutOfBounds
));
}
{
@ -128,7 +143,10 @@ fn coo_try_from_triplets_reports_out_of_bounds_indices() {
let j = vec![0, 2, 1, 5, 3];
let v = vec![2, 3, 7, 3, 1];
let result = CooMatrix::<i32>::try_from_triplets(3, 5, i, j, v);
assert!(matches!(result.unwrap_err().kind(), SparseFormatErrorKind::IndexOutOfBounds));
assert!(matches!(
result.unwrap_err().kind(),
SparseFormatErrorKind::IndexOutOfBounds
));
}
}
@ -137,16 +155,55 @@ fn coo_try_from_triplets_panics_on_mismatched_vectors() {
// Check that try_from_triplets panics when the triplet vectors have different lengths
macro_rules! assert_errs {
($result:expr) => {
assert!(matches!($result.unwrap_err().kind(), SparseFormatErrorKind::InvalidStructure))
}
assert!(matches!(
$result.unwrap_err().kind(),
SparseFormatErrorKind::InvalidStructure
))
};
}
assert_errs!(CooMatrix::<i32>::try_from_triplets(3, 5, vec![1, 2], vec![0], vec![0]));
assert_errs!(CooMatrix::<i32>::try_from_triplets(3, 5, vec![1], vec![0, 0], vec![0]));
assert_errs!(CooMatrix::<i32>::try_from_triplets(3, 5, vec![1], vec![0], vec![0, 1]));
assert_errs!(CooMatrix::<i32>::try_from_triplets(3, 5, vec![1, 2], vec![0, 1], vec![0]));
assert_errs!(CooMatrix::<i32>::try_from_triplets(3, 5, vec![1], vec![0, 1], vec![0, 1]));
assert_errs!(CooMatrix::<i32>::try_from_triplets(3, 5, vec![1, 1], vec![0], vec![0, 1]));
assert_errs!(CooMatrix::<i32>::try_from_triplets(
3,
5,
vec![1, 2],
vec![0],
vec![0]
));
assert_errs!(CooMatrix::<i32>::try_from_triplets(
3,
5,
vec![1],
vec![0, 0],
vec![0]
));
assert_errs!(CooMatrix::<i32>::try_from_triplets(
3,
5,
vec![1],
vec![0],
vec![0, 1]
));
assert_errs!(CooMatrix::<i32>::try_from_triplets(
3,
5,
vec![1, 2],
vec![0, 1],
vec![0]
));
assert_errs!(CooMatrix::<i32>::try_from_triplets(
3,
5,
vec![1],
vec![0, 1],
vec![0, 1]
));
assert_errs!(CooMatrix::<i32>::try_from_triplets(
3,
5,
vec![1, 1],
vec![0],
vec![0, 1]
));
}
#[test]
@ -157,10 +214,16 @@ fn coo_push_valid_entries() {
assert_eq!(coo.triplet_iter().collect::<Vec<_>>(), vec![(0, 0, &1)]);
coo.push(0, 0, 2);
assert_eq!(coo.triplet_iter().collect::<Vec<_>>(), vec![(0, 0, &1), (0, 0, &2)]);
assert_eq!(
coo.triplet_iter().collect::<Vec<_>>(),
vec![(0, 0, &1), (0, 0, &2)]
);
coo.push(2, 2, 3);
assert_eq!(coo.triplet_iter().collect::<Vec<_>>(), vec![(0, 0, &1), (0, 0, &2), (2, 2, &3)]);
assert_eq!(
coo.triplet_iter().collect::<Vec<_>>(),
vec![(0, 0, &1), (0, 0, &2), (2, 2, &3)]
);
}
#[test]

View File

@ -1,6 +1,6 @@
use nalgebra::DMatrix;
use nalgebra_sparse::csc::CscMatrix;
use nalgebra_sparse::SparseFormatErrorKind;
use nalgebra::DMatrix;
use proptest::prelude::*;
use proptest::sample::subsequence;
@ -42,7 +42,10 @@ fn csc_matrix_valid_data() {
assert_eq!(matrix.col_mut(0).row_indices(), &[]);
assert_eq!(matrix.col_mut(0).values(), &[]);
assert_eq!(matrix.col_mut(0).values_mut(), &[]);
assert_eq!(matrix.col_mut(0).rows_and_values_mut(), ([].as_ref(), [].as_mut()));
assert_eq!(
matrix.col_mut(0).rows_and_values_mut(),
([].as_ref(), [].as_mut())
);
assert_eq!(matrix.col(1).nrows(), 2);
assert_eq!(matrix.col(1).nnz(), 0);
@ -53,7 +56,10 @@ fn csc_matrix_valid_data() {
assert_eq!(matrix.col_mut(1).row_indices(), &[]);
assert_eq!(matrix.col_mut(1).values(), &[]);
assert_eq!(matrix.col_mut(1).values_mut(), &[]);
assert_eq!(matrix.col_mut(1).rows_and_values_mut(), ([].as_ref(), [].as_mut()));
assert_eq!(
matrix.col_mut(1).rows_and_values_mut(),
([].as_ref(), [].as_mut())
);
assert_eq!(matrix.col(2).nrows(), 2);
assert_eq!(matrix.col(2).nnz(), 0);
@ -64,7 +70,10 @@ fn csc_matrix_valid_data() {
assert_eq!(matrix.col_mut(2).row_indices(), &[]);
assert_eq!(matrix.col_mut(2).values(), &[]);
assert_eq!(matrix.col_mut(2).values_mut(), &[]);
assert_eq!(matrix.col_mut(2).rows_and_values_mut(), ([].as_ref(), [].as_mut()));
assert_eq!(
matrix.col_mut(2).rows_and_values_mut(),
([].as_ref(), [].as_mut())
);
assert!(matrix.get_col(3).is_none());
assert!(matrix.get_col_mut(3).is_none());
@ -81,11 +90,9 @@ fn csc_matrix_valid_data() {
let offsets = vec![0, 2, 2, 5];
let indices = vec![0, 5, 1, 2, 3];
let values = vec![0, 1, 2, 3, 4];
let mut matrix = CscMatrix::try_from_csc_data(6,
3,
offsets.clone(),
indices.clone(),
values.clone()).unwrap();
let mut matrix =
CscMatrix::try_from_csc_data(6, 3, offsets.clone(), indices.clone(), values.clone())
.unwrap();
assert_eq!(matrix.nrows(), 6);
assert_eq!(matrix.ncols(), 3);
@ -95,10 +102,20 @@ fn csc_matrix_valid_data() {
assert_eq!(matrix.values(), &[0, 1, 2, 3, 4]);
let expected_triplets = vec![(0, 0, 0), (5, 0, 1), (1, 2, 2), (2, 2, 3), (3, 2, 4)];
assert_eq!(matrix.triplet_iter().map(|(i, j, v)| (i, j, *v)).collect::<Vec<_>>(),
expected_triplets);
assert_eq!(matrix.triplet_iter_mut().map(|(i, j, v)| (i, j, *v)).collect::<Vec<_>>(),
expected_triplets);
assert_eq!(
matrix
.triplet_iter()
.map(|(i, j, v)| (i, j, *v))
.collect::<Vec<_>>(),
expected_triplets
);
assert_eq!(
matrix
.triplet_iter_mut()
.map(|(i, j, v)| (i, j, *v))
.collect::<Vec<_>>(),
expected_triplets
);
assert_eq!(matrix.col(0).nrows(), 6);
assert_eq!(matrix.col(0).nnz(), 2);
@ -109,7 +126,10 @@ fn csc_matrix_valid_data() {
assert_eq!(matrix.col_mut(0).row_indices(), &[0, 5]);
assert_eq!(matrix.col_mut(0).values(), &[0, 1]);
assert_eq!(matrix.col_mut(0).values_mut(), &[0, 1]);
assert_eq!(matrix.col_mut(0).rows_and_values_mut(), ([0, 5].as_ref(), [0, 1].as_mut()));
assert_eq!(
matrix.col_mut(0).rows_and_values_mut(),
([0, 5].as_ref(), [0, 1].as_mut())
);
assert_eq!(matrix.col(1).nrows(), 6);
assert_eq!(matrix.col(1).nnz(), 0);
@ -120,7 +140,10 @@ fn csc_matrix_valid_data() {
assert_eq!(matrix.col_mut(1).row_indices(), &[]);
assert_eq!(matrix.col_mut(1).values(), &[]);
assert_eq!(matrix.col_mut(1).values_mut(), &[]);
assert_eq!(matrix.col_mut(1).rows_and_values_mut(), ([].as_ref(), [].as_mut()));
assert_eq!(
matrix.col_mut(1).rows_and_values_mut(),
([].as_ref(), [].as_mut())
);
assert_eq!(matrix.col(2).nrows(), 6);
assert_eq!(matrix.col(2).nnz(), 3);
@ -131,7 +154,10 @@ fn csc_matrix_valid_data() {
assert_eq!(matrix.col_mut(2).row_indices(), &[1, 2, 3]);
assert_eq!(matrix.col_mut(2).values(), &[2, 3, 4]);
assert_eq!(matrix.col_mut(2).values_mut(), &[2, 3, 4]);
assert_eq!(matrix.col_mut(2).rows_and_values_mut(), ([1, 2, 3].as_ref(), [2, 3, 4].as_mut()));
assert_eq!(
matrix.col_mut(2).rows_and_values_mut(),
([1, 2, 3].as_ref(), [2, 3, 4].as_mut())
);
assert!(matrix.get_col(3).is_none());
assert!(matrix.get_col_mut(3).is_none());
@ -146,11 +172,13 @@ fn csc_matrix_valid_data() {
#[test]
fn csc_matrix_try_from_invalid_csc_data() {
{
// Empty offset array (invalid length)
let matrix = CscMatrix::try_from_csc_data(0, 0, Vec::new(), Vec::new(), Vec::<u32>::new());
assert_eq!(matrix.unwrap_err().kind(), &SparseFormatErrorKind::InvalidStructure);
assert_eq!(
matrix.unwrap_err().kind(),
&SparseFormatErrorKind::InvalidStructure
);
}
{
@ -160,7 +188,10 @@ fn csc_matrix_try_from_invalid_csc_data() {
let values = vec![0, 1, 2, 3, 4];
let matrix = CscMatrix::try_from_csc_data(6, 3, offsets, indices, values);
assert_eq!(matrix.unwrap_err().kind(), &SparseFormatErrorKind::InvalidStructure);
assert_eq!(
matrix.unwrap_err().kind(),
&SparseFormatErrorKind::InvalidStructure
);
}
{
@ -169,7 +200,10 @@ fn csc_matrix_try_from_invalid_csc_data() {
let indices = vec![0, 5, 1, 2, 3];
let values = vec![0, 1, 2, 3, 4];
let matrix = CscMatrix::try_from_csc_data(6, 3, offsets, indices, values);
assert_eq!(matrix.unwrap_err().kind(), &SparseFormatErrorKind::InvalidStructure);
assert_eq!(
matrix.unwrap_err().kind(),
&SparseFormatErrorKind::InvalidStructure
);
}
{
@ -178,7 +212,10 @@ fn csc_matrix_try_from_invalid_csc_data() {
let indices = vec![0, 5, 1, 2, 3];
let values = vec![0, 1, 2, 3, 4];
let matrix = CscMatrix::try_from_csc_data(6, 3, offsets, indices, values);
assert_eq!(matrix.unwrap_err().kind(), &SparseFormatErrorKind::InvalidStructure);
assert_eq!(
matrix.unwrap_err().kind(),
&SparseFormatErrorKind::InvalidStructure
);
}
{
@ -187,7 +224,10 @@ fn csc_matrix_try_from_invalid_csc_data() {
let indices = vec![0, 5, 1, 2, 3];
let values = vec![0, 1, 2, 3, 4];
let matrix = CscMatrix::try_from_csc_data(6, 3, offsets, indices, values);
assert_eq!(matrix.unwrap_err().kind(), &SparseFormatErrorKind::InvalidStructure);
assert_eq!(
matrix.unwrap_err().kind(),
&SparseFormatErrorKind::InvalidStructure
);
}
{
@ -196,7 +236,10 @@ fn csc_matrix_try_from_invalid_csc_data() {
let indices = vec![0, 1, 2, 3, 4];
let values = vec![0, 1, 2, 3, 4];
let matrix = CscMatrix::try_from_csc_data(6, 3, offsets, indices, values);
assert_eq!(matrix.unwrap_err().kind(), &SparseFormatErrorKind::InvalidStructure);
assert_eq!(
matrix.unwrap_err().kind(),
&SparseFormatErrorKind::InvalidStructure
);
}
{
@ -205,7 +248,10 @@ fn csc_matrix_try_from_invalid_csc_data() {
let indices = vec![0, 2, 3, 1, 4];
let values = vec![0, 1, 2, 3, 4];
let matrix = CscMatrix::try_from_csc_data(6, 3, offsets, indices, values);
assert_eq!(matrix.unwrap_err().kind(), &SparseFormatErrorKind::InvalidStructure);
assert_eq!(
matrix.unwrap_err().kind(),
&SparseFormatErrorKind::InvalidStructure
);
}
{
@ -214,7 +260,10 @@ fn csc_matrix_try_from_invalid_csc_data() {
let indices = vec![0, 6, 1, 2, 3];
let values = vec![0, 1, 2, 3, 4];
let matrix = CscMatrix::try_from_csc_data(6, 3, offsets, indices, values);
assert_eq!(matrix.unwrap_err().kind(), &SparseFormatErrorKind::IndexOutOfBounds);
assert_eq!(
matrix.unwrap_err().kind(),
&SparseFormatErrorKind::IndexOutOfBounds
);
}
{
@ -223,9 +272,11 @@ fn csc_matrix_try_from_invalid_csc_data() {
let indices = vec![0, 5, 2, 2, 3];
let values = vec![0, 1, 2, 3, 4];
let matrix = CscMatrix::try_from_csc_data(6, 3, offsets, indices, values);
assert_eq!(matrix.unwrap_err().kind(), &SparseFormatErrorKind::DuplicateEntry);
assert_eq!(
matrix.unwrap_err().kind(),
&SparseFormatErrorKind::DuplicateEntry
);
}
}
#[test]
@ -239,11 +290,7 @@ fn csc_disassemble_avoids_clone_when_owned() {
let offsets_ptr = offsets.as_ptr();
let indices_ptr = indices.as_ptr();
let values_ptr = values.as_ptr();
let matrix = CscMatrix::try_from_csc_data(6,
3,
offsets,
indices,
values).unwrap();
let matrix = CscMatrix::try_from_csc_data(6, 3, offsets, indices, values).unwrap();
let (offsets, indices, values) = matrix.disassemble();
assert_eq!(offsets.as_ptr(), offsets_ptr);

View File

@ -1,6 +1,6 @@
use nalgebra::DMatrix;
use nalgebra_sparse::csr::CsrMatrix;
use nalgebra_sparse::SparseFormatErrorKind;
use nalgebra::DMatrix;
use proptest::prelude::*;
use proptest::sample::subsequence;
@ -9,7 +9,6 @@ use crate::common::csr_strategy;
use std::collections::HashSet;
#[test]
fn csr_matrix_valid_data() {
// Construct matrix from valid data and check that selected methods return results
@ -43,7 +42,10 @@ fn csr_matrix_valid_data() {
assert_eq!(matrix.row_mut(0).col_indices(), &[]);
assert_eq!(matrix.row_mut(0).values(), &[]);
assert_eq!(matrix.row_mut(0).values_mut(), &[]);
assert_eq!(matrix.row_mut(0).cols_and_values_mut(), ([].as_ref(), [].as_mut()));
assert_eq!(
matrix.row_mut(0).cols_and_values_mut(),
([].as_ref(), [].as_mut())
);
assert_eq!(matrix.row(1).ncols(), 2);
assert_eq!(matrix.row(1).nnz(), 0);
@ -54,7 +56,10 @@ fn csr_matrix_valid_data() {
assert_eq!(matrix.row_mut(1).col_indices(), &[]);
assert_eq!(matrix.row_mut(1).values(), &[]);
assert_eq!(matrix.row_mut(1).values_mut(), &[]);
assert_eq!(matrix.row_mut(1).cols_and_values_mut(), ([].as_ref(), [].as_mut()));
assert_eq!(
matrix.row_mut(1).cols_and_values_mut(),
([].as_ref(), [].as_mut())
);
assert_eq!(matrix.row(2).ncols(), 2);
assert_eq!(matrix.row(2).nnz(), 0);
@ -65,7 +70,10 @@ fn csr_matrix_valid_data() {
assert_eq!(matrix.row_mut(2).col_indices(), &[]);
assert_eq!(matrix.row_mut(2).values(), &[]);
assert_eq!(matrix.row_mut(2).values_mut(), &[]);
assert_eq!(matrix.row_mut(2).cols_and_values_mut(), ([].as_ref(), [].as_mut()));
assert_eq!(
matrix.row_mut(2).cols_and_values_mut(),
([].as_ref(), [].as_mut())
);
assert!(matrix.get_row(3).is_none());
assert!(matrix.get_row_mut(3).is_none());
@ -82,11 +90,9 @@ fn csr_matrix_valid_data() {
let offsets = vec![0, 2, 2, 5];
let indices = vec![0, 5, 1, 2, 3];
let values = vec![0, 1, 2, 3, 4];
let mut matrix = CsrMatrix::try_from_csr_data(3,
6,
offsets.clone(),
indices.clone(),
values.clone()).unwrap();
let mut matrix =
CsrMatrix::try_from_csr_data(3, 6, offsets.clone(), indices.clone(), values.clone())
.unwrap();
assert_eq!(matrix.nrows(), 3);
assert_eq!(matrix.ncols(), 6);
@ -96,10 +102,20 @@ fn csr_matrix_valid_data() {
assert_eq!(matrix.values(), &[0, 1, 2, 3, 4]);
let expected_triplets = vec![(0, 0, 0), (0, 5, 1), (2, 1, 2), (2, 2, 3), (2, 3, 4)];
assert_eq!(matrix.triplet_iter().map(|(i, j, v)| (i, j, *v)).collect::<Vec<_>>(),
expected_triplets);
assert_eq!(matrix.triplet_iter_mut().map(|(i, j, v)| (i, j, *v)).collect::<Vec<_>>(),
expected_triplets);
assert_eq!(
matrix
.triplet_iter()
.map(|(i, j, v)| (i, j, *v))
.collect::<Vec<_>>(),
expected_triplets
);
assert_eq!(
matrix
.triplet_iter_mut()
.map(|(i, j, v)| (i, j, *v))
.collect::<Vec<_>>(),
expected_triplets
);
assert_eq!(matrix.row(0).ncols(), 6);
assert_eq!(matrix.row(0).nnz(), 2);
@ -110,7 +126,10 @@ fn csr_matrix_valid_data() {
assert_eq!(matrix.row_mut(0).col_indices(), &[0, 5]);
assert_eq!(matrix.row_mut(0).values(), &[0, 1]);
assert_eq!(matrix.row_mut(0).values_mut(), &[0, 1]);
assert_eq!(matrix.row_mut(0).cols_and_values_mut(), ([0, 5].as_ref(), [0, 1].as_mut()));
assert_eq!(
matrix.row_mut(0).cols_and_values_mut(),
([0, 5].as_ref(), [0, 1].as_mut())
);
assert_eq!(matrix.row(1).ncols(), 6);
assert_eq!(matrix.row(1).nnz(), 0);
@ -121,7 +140,10 @@ fn csr_matrix_valid_data() {
assert_eq!(matrix.row_mut(1).col_indices(), &[]);
assert_eq!(matrix.row_mut(1).values(), &[]);
assert_eq!(matrix.row_mut(1).values_mut(), &[]);
assert_eq!(matrix.row_mut(1).cols_and_values_mut(), ([].as_ref(), [].as_mut()));
assert_eq!(
matrix.row_mut(1).cols_and_values_mut(),
([].as_ref(), [].as_mut())
);
assert_eq!(matrix.row(2).ncols(), 6);
assert_eq!(matrix.row(2).nnz(), 3);
@ -132,7 +154,10 @@ fn csr_matrix_valid_data() {
assert_eq!(matrix.row_mut(2).col_indices(), &[1, 2, 3]);
assert_eq!(matrix.row_mut(2).values(), &[2, 3, 4]);
assert_eq!(matrix.row_mut(2).values_mut(), &[2, 3, 4]);
assert_eq!(matrix.row_mut(2).cols_and_values_mut(), ([1, 2, 3].as_ref(), [2, 3, 4].as_mut()));
assert_eq!(
matrix.row_mut(2).cols_and_values_mut(),
([1, 2, 3].as_ref(), [2, 3, 4].as_mut())
);
assert!(matrix.get_row(3).is_none());
assert!(matrix.get_row_mut(3).is_none());
@ -147,11 +172,13 @@ fn csr_matrix_valid_data() {
#[test]
fn csr_matrix_try_from_invalid_csr_data() {
{
// Empty offset array (invalid length)
let matrix = CsrMatrix::try_from_csr_data(0, 0, Vec::new(), Vec::new(), Vec::<u32>::new());
assert_eq!(matrix.unwrap_err().kind(), &SparseFormatErrorKind::InvalidStructure);
assert_eq!(
matrix.unwrap_err().kind(),
&SparseFormatErrorKind::InvalidStructure
);
}
{
@ -161,7 +188,10 @@ fn csr_matrix_try_from_invalid_csr_data() {
let values = vec![0, 1, 2, 3, 4];
let matrix = CsrMatrix::try_from_csr_data(3, 6, offsets, indices, values);
assert_eq!(matrix.unwrap_err().kind(), &SparseFormatErrorKind::InvalidStructure);
assert_eq!(
matrix.unwrap_err().kind(),
&SparseFormatErrorKind::InvalidStructure
);
}
{
@ -170,7 +200,10 @@ fn csr_matrix_try_from_invalid_csr_data() {
let indices = vec![0, 5, 1, 2, 3];
let values = vec![0, 1, 2, 3, 4];
let matrix = CsrMatrix::try_from_csr_data(3, 6, offsets, indices, values);
assert_eq!(matrix.unwrap_err().kind(), &SparseFormatErrorKind::InvalidStructure);
assert_eq!(
matrix.unwrap_err().kind(),
&SparseFormatErrorKind::InvalidStructure
);
}
{
@ -179,7 +212,10 @@ fn csr_matrix_try_from_invalid_csr_data() {
let indices = vec![0, 5, 1, 2, 3];
let values = vec![0, 1, 2, 3, 4];
let matrix = CsrMatrix::try_from_csr_data(3, 6, offsets, indices, values);
assert_eq!(matrix.unwrap_err().kind(), &SparseFormatErrorKind::InvalidStructure);
assert_eq!(
matrix.unwrap_err().kind(),
&SparseFormatErrorKind::InvalidStructure
);
}
{
@ -188,7 +224,10 @@ fn csr_matrix_try_from_invalid_csr_data() {
let indices = vec![0, 5, 1, 2, 3];
let values = vec![0, 1, 2, 3, 4];
let matrix = CsrMatrix::try_from_csr_data(3, 6, offsets, indices, values);
assert_eq!(matrix.unwrap_err().kind(), &SparseFormatErrorKind::InvalidStructure);
assert_eq!(
matrix.unwrap_err().kind(),
&SparseFormatErrorKind::InvalidStructure
);
}
{
@ -197,7 +236,10 @@ fn csr_matrix_try_from_invalid_csr_data() {
let indices = vec![0, 1, 2, 3, 4];
let values = vec![0, 1, 2, 3, 4];
let matrix = CsrMatrix::try_from_csr_data(3, 6, offsets, indices, values);
assert_eq!(matrix.unwrap_err().kind(), &SparseFormatErrorKind::InvalidStructure);
assert_eq!(
matrix.unwrap_err().kind(),
&SparseFormatErrorKind::InvalidStructure
);
}
{
@ -206,7 +248,10 @@ fn csr_matrix_try_from_invalid_csr_data() {
let indices = vec![0, 2, 3, 1, 4];
let values = vec![0, 1, 2, 3, 4];
let matrix = CsrMatrix::try_from_csr_data(3, 6, offsets, indices, values);
assert_eq!(matrix.unwrap_err().kind(), &SparseFormatErrorKind::InvalidStructure);
assert_eq!(
matrix.unwrap_err().kind(),
&SparseFormatErrorKind::InvalidStructure
);
}
{
@ -215,7 +260,10 @@ fn csr_matrix_try_from_invalid_csr_data() {
let indices = vec![0, 6, 1, 2, 3];
let values = vec![0, 1, 2, 3, 4];
let matrix = CsrMatrix::try_from_csr_data(3, 6, offsets, indices, values);
assert_eq!(matrix.unwrap_err().kind(), &SparseFormatErrorKind::IndexOutOfBounds);
assert_eq!(
matrix.unwrap_err().kind(),
&SparseFormatErrorKind::IndexOutOfBounds
);
}
{
@ -224,9 +272,11 @@ fn csr_matrix_try_from_invalid_csr_data() {
let indices = vec![0, 5, 2, 2, 3];
let values = vec![0, 1, 2, 3, 4];
let matrix = CsrMatrix::try_from_csr_data(3, 6, offsets, indices, values);
assert_eq!(matrix.unwrap_err().kind(), &SparseFormatErrorKind::DuplicateEntry);
assert_eq!(
matrix.unwrap_err().kind(),
&SparseFormatErrorKind::DuplicateEntry
);
}
}
#[test]
@ -240,11 +290,7 @@ fn csr_disassemble_avoids_clone_when_owned() {
let offsets_ptr = offsets.as_ptr();
let indices_ptr = indices.as_ptr();
let values_ptr = values.as_ptr();
let matrix = CsrMatrix::try_from_csr_data(3,
6,
offsets,
indices,
values).unwrap();
let matrix = CsrMatrix::try_from_csr_data(3, 6, offsets, indices, values).unwrap();
let (offsets, indices, values) = matrix.disassemble();
assert_eq!(offsets.as_ptr(), offsets_ptr);

View File

@ -1,8 +1,8 @@
mod coo;
mod cholesky;
mod convert_serial;
mod coo;
mod csc;
mod csr;
mod ops;
mod pattern;
mod csr;
mod csc;
mod convert_serial;
mod proptest;

View File

@ -1,13 +1,19 @@
use crate::common::{csc_strategy, csr_strategy, PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ, PROPTEST_I32_VALUE_STRATEGY, non_zero_i32_value_strategy, value_strategy};
use nalgebra_sparse::ops::serial::{spmm_csr_dense, spmm_csc_dense, spadd_pattern, spadd_csr_prealloc, spadd_csc_prealloc, spmm_csr_prealloc, spmm_csc_prealloc, spsolve_csc_lower_triangular, spmm_csr_pattern};
use nalgebra_sparse::ops::{Op};
use nalgebra_sparse::csr::CsrMatrix;
use crate::common::{
csc_strategy, csr_strategy, non_zero_i32_value_strategy, value_strategy,
PROPTEST_I32_VALUE_STRATEGY, PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ,
};
use nalgebra_sparse::csc::CscMatrix;
use nalgebra_sparse::proptest::{csc, csr, sparsity_pattern};
use nalgebra_sparse::csr::CsrMatrix;
use nalgebra_sparse::ops::serial::{
spadd_csc_prealloc, spadd_csr_prealloc, spadd_pattern, spmm_csc_dense, spmm_csc_prealloc,
spmm_csr_dense, spmm_csr_pattern, spmm_csr_prealloc, spsolve_csc_lower_triangular,
};
use nalgebra_sparse::ops::Op;
use nalgebra_sparse::pattern::SparsityPattern;
use nalgebra_sparse::proptest::{csc, csr, sparsity_pattern};
use nalgebra::{DMatrix, Scalar, DMatrixSliceMut, DMatrixSlice};
use nalgebra::proptest::{matrix, vector};
use nalgebra::{DMatrix, DMatrixSlice, DMatrixSliceMut, Scalar};
use proptest::prelude::*;
@ -17,19 +23,15 @@ use std::panic::catch_unwind;
/// Represents the sparsity pattern of a CSR matrix as a dense matrix with 0/1
fn dense_csr_pattern(pattern: &SparsityPattern) -> DMatrix<i32> {
let boolean_csr = CsrMatrix::try_from_pattern_and_values(
pattern.clone(),
vec![1; pattern.nnz()])
.unwrap();
let boolean_csr =
CsrMatrix::try_from_pattern_and_values(pattern.clone(), vec![1; pattern.nnz()]).unwrap();
DMatrix::from(&boolean_csr)
}
/// Represents the sparsity pattern of a CSC matrix as a dense matrix with 0/1
fn dense_csc_pattern(pattern: &SparsityPattern) -> DMatrix<i32> {
let boolean_csc = CscMatrix::try_from_pattern_and_values(
pattern.clone(),
vec![1; pattern.nnz()])
.unwrap();
let boolean_csc =
CscMatrix::try_from_pattern_and_values(pattern.clone(), vec![1; pattern.nnz()]).unwrap();
DMatrix::from(&boolean_csc)
}
@ -62,14 +64,23 @@ fn spmm_csr_dense_args_strategy() -> impl Strategy<Value=SpmmCsrDenseArgs<i32>>
let trans_strategy = trans_strategy();
let c_matrix_strategy = matrix(value_strategy.clone(), c_rows, c_cols);
(c_matrix_strategy, common_dim, trans_strategy.clone(), trans_strategy.clone())
(
c_matrix_strategy,
common_dim,
trans_strategy.clone(),
trans_strategy.clone(),
)
.prop_flat_map(move |(c, common_dim, trans_a, trans_b)| {
let a_shape =
if trans_a { (common_dim, c.nrows()) }
else { (c.nrows(), common_dim) };
let b_shape =
if trans_b { (c.ncols(), common_dim) }
else { (common_dim, c.ncols()) };
let a_shape = if trans_a {
(common_dim, c.nrows())
} else {
(c.nrows(), common_dim)
};
let b_shape = if trans_b {
(c.ncols(), common_dim)
} else {
(common_dim, c.ncols())
};
let a = csr(value_strategy.clone(), a_shape.0, a_shape.1, max_nnz);
let b = matrix(value_strategy.clone(), b_shape.0, b_shape.1);
@ -78,29 +89,35 @@ fn spmm_csr_dense_args_strategy() -> impl Strategy<Value=SpmmCsrDenseArgs<i32>>
let beta = value_strategy.clone();
(Just(c), beta, alpha, Just(trans_a), a, Just(trans_b), b)
}).prop_map(|(c, beta, alpha, trans_a, a, trans_b, b)| {
SpmmCsrDenseArgs {
})
.prop_map(
|(c, beta, alpha, trans_a, a, trans_b, b)| SpmmCsrDenseArgs {
c,
beta,
alpha,
a: if trans_a { Op::Transpose(a) } else { Op::NoOp(a) },
b: if trans_b { Op::Transpose(b) } else { Op::NoOp(b) },
}
})
a: if trans_a {
Op::Transpose(a)
} else {
Op::NoOp(a)
},
b: if trans_b {
Op::Transpose(b)
} else {
Op::NoOp(b)
},
},
)
}
/// Returns matrices C, A and B with compatible dimensions such that it can be used
/// in an `spmm` operation `C = beta * C + alpha * trans(A) * trans(B)`.
fn spmm_csc_dense_args_strategy() -> impl Strategy<Value = SpmmCscDenseArgs<i32>> {
spmm_csr_dense_args_strategy()
.prop_map(|args| {
SpmmCscDenseArgs {
spmm_csr_dense_args_strategy().prop_map(|args| SpmmCscDenseArgs {
c: args.c,
beta: args.beta,
alpha: args.alpha,
a: args.a.map_same_op(|a| CscMatrix::from(&a)),
b: args.b
}
b: args.b,
})
}
@ -131,28 +148,46 @@ fn spadd_csr_prealloc_args_strategy() -> impl Strategy<Value=SpaddCsrArgs<i32>>
let c_values = vec![value_strategy.clone(); c_pattern.nnz()];
let alpha = value_strategy.clone();
let beta = value_strategy.clone();
(Just(c_pattern), Just(a_pattern), c_values, a_values, alpha, beta, trans_strategy())
}).prop_map(|(c_pattern, a_pattern, c_values, a_values, alpha, beta, trans_a)| {
(
Just(c_pattern),
Just(a_pattern),
c_values,
a_values,
alpha,
beta,
trans_strategy(),
)
})
.prop_map(
|(c_pattern, a_pattern, c_values, a_values, alpha, beta, trans_a)| {
let c = CsrMatrix::try_from_pattern_and_values(c_pattern, c_values).unwrap();
let a = CsrMatrix::try_from_pattern_and_values(a_pattern, a_values).unwrap();
let a = if trans_a { Op::Transpose(a.transpose()) } else { Op::NoOp(a) };
let a = if trans_a {
Op::Transpose(a.transpose())
} else {
Op::NoOp(a)
};
SpaddCsrArgs { c, beta, alpha, a }
})
},
)
}
fn spadd_csc_prealloc_args_strategy() -> impl Strategy<Value = SpaddCscArgs<i32>> {
spadd_csr_prealloc_args_strategy()
.prop_map(|args| SpaddCscArgs {
spadd_csr_prealloc_args_strategy().prop_map(|args| SpaddCscArgs {
c: CscMatrix::from(&args.c),
beta: args.beta,
alpha: args.alpha,
a: args.a.map_same_op(|a| CscMatrix::from(&a))
a: args.a.map_same_op(|a| CscMatrix::from(&a)),
})
}
fn dense_strategy() -> impl Strategy<Value = DMatrix<i32>> {
matrix(PROPTEST_I32_VALUE_STRATEGY, PROPTEST_MATRIX_DIM, PROPTEST_MATRIX_DIM)
matrix(
PROPTEST_I32_VALUE_STRATEGY,
PROPTEST_MATRIX_DIM,
PROPTEST_MATRIX_DIM,
)
}
fn trans_strategy() -> impl Strategy<Value = bool> + Clone {
@ -163,11 +198,12 @@ fn trans_strategy() -> impl Strategy<Value=bool> + Clone {
/// values.
fn op_strategy<S: Strategy>(strategy: S) -> impl Strategy<Value = Op<S::Value>> {
let is_transposed = proptest::bool::ANY;
(strategy, is_transposed)
.prop_map(|(obj, is_trans)| if is_trans {
(strategy, is_transposed).prop_map(|(obj, is_trans)| {
if is_trans {
Op::Transpose(obj)
} else {
Op::NoOp(obj)
}
})
}
@ -177,8 +213,7 @@ fn pattern_strategy() -> impl Strategy<Value=SparsityPattern> {
/// Constructs pairs (a, b) where a and b have the same dimensions
fn spadd_pattern_strategy() -> impl Strategy<Value = (SparsityPattern, SparsityPattern)> {
pattern_strategy()
.prop_flat_map(|a| {
pattern_strategy().prop_flat_map(|a| {
let b = sparsity_pattern(a.major_dim(), a.minor_dim(), PROPTEST_MAX_NNZ);
(Just(a), b)
})
@ -186,8 +221,7 @@ fn spadd_pattern_strategy() -> impl Strategy<Value=(SparsityPattern, SparsityPat
/// Constructs pairs (a, b) where a and b have compatible dimensions for a matrix product
fn spmm_csr_pattern_strategy() -> impl Strategy<Value = (SparsityPattern, SparsityPattern)> {
pattern_strategy()
.prop_flat_map(|a| {
pattern_strategy().prop_flat_map(|a| {
let b = sparsity_pattern(a.minor_dim(), PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ);
(Just(a), b)
})
@ -218,48 +252,55 @@ fn spmm_csr_prealloc_args_strategy() -> impl Strategy<Value=SpmmCsrArgs<i32>> {
let b_values = vec![PROPTEST_I32_VALUE_STRATEGY; b_pattern.nnz()];
let c_pattern = spmm_csr_pattern(&a_pattern, &b_pattern);
let c_values = vec![PROPTEST_I32_VALUE_STRATEGY; c_pattern.nnz()];
let a = a_values.prop_map(move |values|
CsrMatrix::try_from_pattern_and_values(a_pattern.clone(), values).unwrap());
let b = b_values.prop_map(move |values|
CsrMatrix::try_from_pattern_and_values(b_pattern.clone(), values).unwrap());
let c = c_values.prop_map(move |values|
CsrMatrix::try_from_pattern_and_values(c_pattern.clone(), values).unwrap());
let a = a_values.prop_map(move |values| {
CsrMatrix::try_from_pattern_and_values(a_pattern.clone(), values).unwrap()
});
let b = b_values.prop_map(move |values| {
CsrMatrix::try_from_pattern_and_values(b_pattern.clone(), values).unwrap()
});
let c = c_values.prop_map(move |values| {
CsrMatrix::try_from_pattern_and_values(c_pattern.clone(), values).unwrap()
});
let alpha = PROPTEST_I32_VALUE_STRATEGY;
let beta = PROPTEST_I32_VALUE_STRATEGY;
(c, beta, alpha, trans_strategy(), a, trans_strategy(), b)
})
.prop_map(|(c, beta, alpha, trans_a, a, trans_b, b)| {
SpmmCsrArgs::<i32> {
.prop_map(
|(c, beta, alpha, trans_a, a, trans_b, b)| SpmmCsrArgs::<i32> {
c,
beta,
alpha,
a: if trans_a { Op::Transpose(a.transpose()) } else { Op::NoOp(a) },
b: if trans_b { Op::Transpose(b.transpose()) } else { Op::NoOp(b) }
}
})
a: if trans_a {
Op::Transpose(a.transpose())
} else {
Op::NoOp(a)
},
b: if trans_b {
Op::Transpose(b.transpose())
} else {
Op::NoOp(b)
},
},
)
}
fn spmm_csc_prealloc_args_strategy() -> impl Strategy<Value = SpmmCscArgs<i32>> {
// Note: Converting from CSR is simple, but might be significantly slower than
// writing a common implementation that can be shared between CSR and CSC args
spmm_csr_prealloc_args_strategy()
.prop_map(|args| {
SpmmCscArgs {
spmm_csr_prealloc_args_strategy().prop_map(|args| SpmmCscArgs {
c: CscMatrix::from(&args.c),
beta: args.beta,
alpha: args.alpha,
a: args.a.map_same_op(|a| CscMatrix::from(&a)),
b: args.b.map_same_op(|b| CscMatrix::from(&b))
}
b: args.b.map_same_op(|b| CscMatrix::from(&b)),
})
}
fn csc_invertible_diagonal() -> impl Strategy<Value = CscMatrix<f64>> {
let non_zero_values = value_strategy::<f64>()
.prop_filter("Only non-zeros values accepted", |x| x != &0.0);
let non_zero_values =
value_strategy::<f64>().prop_filter("Only non-zeros values accepted", |x| x != &0.0);
vector(non_zero_values, PROPTEST_MATRIX_DIM)
.prop_map(|d| {
vector(non_zero_values, PROPTEST_MATRIX_DIM).prop_map(|d| {
let mut matrix = CscMatrix::identity(d.len());
matrix.values_mut().clone_from_slice(&d.as_slice());
matrix
@ -267,9 +308,13 @@ fn csc_invertible_diagonal() -> impl Strategy<Value=CscMatrix<f64>> {
}
fn csc_square_with_non_zero_diagonals() -> impl Strategy<Value = CscMatrix<f64>> {
csc_invertible_diagonal()
.prop_flat_map(|d| {
csc(value_strategy::<f64>(), d.nrows(), d.nrows(), PROPTEST_MAX_NNZ)
csc_invertible_diagonal().prop_flat_map(|d| {
csc(
value_strategy::<f64>(),
d.nrows(),
d.nrows(),
PROPTEST_MAX_NNZ,
)
.prop_map(move |mut c| {
for (i, j, v) in c.triplet_iter_mut() {
if i == j {
@ -285,12 +330,13 @@ fn csc_square_with_non_zero_diagonals() -> impl Strategy<Value=CscMatrix<f64>> {
}
/// Helper function to help us call dense GEMM with our `Op` type
fn dense_gemm<'a>(beta: i32,
fn dense_gemm<'a>(
beta: i32,
c: impl Into<DMatrixSliceMut<'a, i32>>,
alpha: i32,
a: Op<impl Into<DMatrixSlice<'a, i32>>>,
b: Op<impl Into<DMatrixSlice<'a, i32>>>)
{
b: Op<impl Into<DMatrixSlice<'a, i32>>>,
) {
let mut c = c.into();
let a = a.convert();
let b = b.convert();
@ -300,7 +346,7 @@ fn dense_gemm<'a>(beta: i32,
(NoOp(a), NoOp(b)) => c.gemm(alpha, &a, &b, beta),
(Transpose(a), NoOp(b)) => c.gemm(alpha, &a.transpose(), &b, beta),
(NoOp(a), Transpose(b)) => c.gemm(alpha, &a, &b.transpose(), beta),
(Transpose(a), Transpose(b)) => c.gemm(alpha, &a.transpose(), &b.transpose(), beta)
(Transpose(a), Transpose(b)) => c.gemm(alpha, &a.transpose(), &b.transpose(), beta),
}
}

View File

@ -7,10 +7,8 @@ fn sparsity_pattern_valid_data() {
{
// A pattern with zero explicitly stored entries
let pattern = SparsityPattern::try_from_offsets_and_indices(3,
2,
vec![0, 0, 0, 0],
Vec::new())
let pattern =
SparsityPattern::try_from_offsets_and_indices(3, 2, vec![0, 0, 0, 0], Vec::new())
.unwrap();
assert_eq!(pattern.major_dim(), 3);
@ -46,8 +44,10 @@ fn sparsity_pattern_valid_data() {
assert_eq!(pattern.lane(0), &[0, 5]);
assert_eq!(pattern.lane(1), &[]);
assert_eq!(pattern.lane(2), &[1, 2, 3]);
assert_eq!(pattern.entries().collect::<Vec<_>>(),
vec![(0, 0), (0, 5), (2, 1), (2, 2), (2, 3)]);
assert_eq!(
pattern.entries().collect::<Vec<_>>(),
vec![(0, 0), (0, 5), (2, 1), (2, 2), (2, 3)]
);
let (offsets2, indices2) = pattern.disassemble();
assert_eq!(offsets2, offsets);
@ -60,7 +60,10 @@ fn sparsity_pattern_try_from_invalid_data() {
{
// Empty offset array (invalid length)
let pattern = SparsityPattern::try_from_offsets_and_indices(0, 0, Vec::new(), Vec::new());
assert_eq!(pattern, Err(SparsityPatternFormatError::InvalidOffsetArrayLength));
assert_eq!(
pattern,
Err(SparsityPatternFormatError::InvalidOffsetArrayLength)
);
}
{
@ -69,7 +72,10 @@ fn sparsity_pattern_try_from_invalid_data() {
let indices = vec![0, 1, 2, 3, 5];
let pattern = SparsityPattern::try_from_offsets_and_indices(3, 6, offsets, indices);
assert!(matches!(pattern, Err(SparsityPatternFormatError::InvalidOffsetArrayLength)));
assert!(matches!(
pattern,
Err(SparsityPatternFormatError::InvalidOffsetArrayLength)
));
}
{
@ -77,7 +83,10 @@ fn sparsity_pattern_try_from_invalid_data() {
let offsets = vec![1, 2, 2, 5];
let indices = vec![0, 5, 1, 2, 3];
let pattern = SparsityPattern::try_from_offsets_and_indices(3, 6, offsets, indices);
assert!(matches!(pattern, Err(SparsityPatternFormatError::InvalidOffsetFirstLast)));
assert!(matches!(
pattern,
Err(SparsityPatternFormatError::InvalidOffsetFirstLast)
));
}
{
@ -85,7 +94,10 @@ fn sparsity_pattern_try_from_invalid_data() {
let offsets = vec![0, 2, 2, 4];
let indices = vec![0, 5, 1, 2, 3];
let pattern = SparsityPattern::try_from_offsets_and_indices(3, 6, offsets, indices);
assert!(matches!(pattern, Err(SparsityPatternFormatError::InvalidOffsetFirstLast)));
assert!(matches!(
pattern,
Err(SparsityPatternFormatError::InvalidOffsetFirstLast)
));
}
{
@ -93,7 +105,10 @@ fn sparsity_pattern_try_from_invalid_data() {
let offsets = vec![0, 2, 2];
let indices = vec![0, 5, 1, 2, 3];
let pattern = SparsityPattern::try_from_offsets_and_indices(3, 6, offsets, indices);
assert!(matches!(pattern, Err(SparsityPatternFormatError::InvalidOffsetArrayLength)));
assert!(matches!(
pattern,
Err(SparsityPatternFormatError::InvalidOffsetArrayLength)
));
}
{
@ -101,7 +116,10 @@ fn sparsity_pattern_try_from_invalid_data() {
let offsets = vec![0, 3, 2, 5];
let indices = vec![0, 1, 2, 3, 4];
let pattern = SparsityPattern::try_from_offsets_and_indices(3, 6, offsets, indices);
assert_eq!(pattern, Err(SparsityPatternFormatError::NonmonotonicOffsets));
assert_eq!(
pattern,
Err(SparsityPatternFormatError::NonmonotonicOffsets)
);
}
{
@ -109,7 +127,10 @@ fn sparsity_pattern_try_from_invalid_data() {
let offsets = vec![0, 2, 2, 5];
let indices = vec![0, 2, 3, 1, 4];
let pattern = SparsityPattern::try_from_offsets_and_indices(3, 6, offsets, indices);
assert_eq!(pattern, Err(SparsityPatternFormatError::NonmonotonicMinorIndices));
assert_eq!(
pattern,
Err(SparsityPatternFormatError::NonmonotonicMinorIndices)
);
}
{
@ -117,7 +138,10 @@ fn sparsity_pattern_try_from_invalid_data() {
let offsets = vec![0, 2, 2, 5];
let indices = vec![0, 6, 1, 2, 3];
let pattern = SparsityPattern::try_from_offsets_and_indices(3, 6, offsets, indices);
assert_eq!(pattern, Err(SparsityPatternFormatError::MinorIndexOutOfBounds));
assert_eq!(
pattern,
Err(SparsityPatternFormatError::MinorIndexOutOfBounds)
);
}
{

View File

@ -6,25 +6,27 @@ fn coo_no_duplicates_generates_admissible_matrices() {
#[cfg(feature = "slow-tests")]
mod slow {
use nalgebra_sparse::proptest::{coo_with_duplicates, coo_no_duplicates, csr, csc, sparsity_pattern};
use nalgebra::DMatrix;
use nalgebra_sparse::proptest::{
coo_no_duplicates, coo_with_duplicates, csc, csr, sparsity_pattern,
};
use proptest::test_runner::TestRunner;
use proptest::strategy::ValueTree;
use itertools::Itertools;
use proptest::strategy::ValueTree;
use proptest::test_runner::TestRunner;
use proptest::prelude::*;
use nalgebra_sparse::csr::CsrMatrix;
use std::collections::HashSet;
use std::iter::repeat;
use std::ops::RangeInclusive;
use nalgebra_sparse::csr::CsrMatrix;
fn generate_all_possible_matrices(value_range: RangeInclusive<i32>,
fn generate_all_possible_matrices(
value_range: RangeInclusive<i32>,
rows_range: RangeInclusive<usize>,
cols_range: RangeInclusive<usize>)
-> HashSet<DMatrix<i32>>
{
cols_range: RangeInclusive<usize>,
) -> HashSet<DMatrix<i32>> {
// Enumerate all possible combinations
let mut all_combinations = HashSet::new();
for nrows in rows_range {
@ -48,7 +50,11 @@ mod slow {
.take(n_values)
.multi_cartesian_product();
for matrix_values in values_iter {
all_combinations.insert(DMatrix::from_row_slice(nrows, ncols, &matrix_values));
all_combinations.insert(DMatrix::from_row_slice(
nrows,
ncols,
&matrix_values,
));
}
}
}
@ -80,12 +86,14 @@ mod slow {
// Enumerate all possible combinations
let all_combinations = generate_all_possible_matrices(values, rows, cols);
let visited_combinations = sample_matrix_output_space(strategy,
&mut runner,
num_generated_matrices);
let visited_combinations =
sample_matrix_output_space(strategy, &mut runner, num_generated_matrices);
assert_eq!(visited_combinations.len(), all_combinations.len());
assert_eq!(visited_combinations, all_combinations, "Did not sample all possible values.");
assert_eq!(
visited_combinations, all_combinations,
"Did not sample all possible values."
);
}
#[cfg(feature = "slow-tests")]
@ -113,9 +121,8 @@ mod slow {
// `coo_with_duplicates`)
let all_combinations = generate_all_possible_matrices(values, rows, cols);
let visited_combinations = sample_matrix_output_space(strategy,
&mut runner,
num_generated_matrices);
let visited_combinations =
sample_matrix_output_space(strategy, &mut runner, num_generated_matrices);
// Here we cannot verify that the set of visited combinations is *equal* to
// all possible outcomes with the given constraints, however the
@ -143,12 +150,14 @@ mod slow {
let all_combinations = generate_all_possible_matrices(values, rows, cols);
let visited_combinations = sample_matrix_output_space(strategy,
&mut runner,
num_generated_matrices);
let visited_combinations =
sample_matrix_output_space(strategy, &mut runner, num_generated_matrices);
assert_eq!(visited_combinations.len(), all_combinations.len());
assert_eq!(visited_combinations, all_combinations, "Did not sample all possible values.");
assert_eq!(
visited_combinations, all_combinations,
"Did not sample all possible values."
);
}
#[cfg(feature = "slow-tests")]
@ -169,12 +178,14 @@ mod slow {
let all_combinations = generate_all_possible_matrices(values, rows, cols);
let visited_combinations = sample_matrix_output_space(strategy,
&mut runner,
num_generated_matrices);
let visited_combinations =
sample_matrix_output_space(strategy, &mut runner, num_generated_matrices);
assert_eq!(visited_combinations.len(), all_combinations.len());
assert_eq!(visited_combinations, all_combinations, "Did not sample all possible values.");
assert_eq!(
visited_combinations, all_combinations,
"Did not sample all possible values."
);
}
#[cfg(feature = "slow-tests")]
@ -206,13 +217,14 @@ mod slow {
assert_eq!(visited_patterns, all_possible_patterns);
}
fn sample_matrix_output_space<S>(strategy: S,
fn sample_matrix_output_space<S>(
strategy: S,
runner: &mut TestRunner,
num_samples: usize)
-> HashSet<DMatrix<i32>>
num_samples: usize,
) -> HashSet<DMatrix<i32>>
where
S: Strategy,
DMatrix<i32>: for<'b> From<&'b S::Value>
DMatrix<i32>: for<'b> From<&'b S::Value>,
{
sample_strategy(strategy, runner)
.take(num_samples)
@ -220,8 +232,10 @@ mod slow {
.collect()
}
fn sample_strategy<'a, S: 'a + Strategy>(strategy: S, runner: &'a mut TestRunner)
-> impl 'a + Iterator<Item=S::Value> {
fn sample_strategy<'a, S: 'a + Strategy>(
strategy: S,
runner: &'a mut TestRunner,
) -> impl 'a + Iterator<Item = S::Value> {
repeat(()).map(move |_| {
let tree = strategy
.new_tree(runner)