rustfmt
This commit is contained in:
parent
795d818ae5
commit
7473d54d74
|
@ -1,17 +1,17 @@
|
|||
use crate::coo::CooMatrix;
|
||||
use crate::convert::serial::*;
|
||||
use nalgebra::{Matrix, Scalar, Dim, ClosedAdd, DMatrix};
|
||||
use nalgebra::storage::{Storage};
|
||||
use num_traits::Zero;
|
||||
use crate::csr::CsrMatrix;
|
||||
use crate::coo::CooMatrix;
|
||||
use crate::csc::CscMatrix;
|
||||
use crate::csr::CsrMatrix;
|
||||
use nalgebra::storage::Storage;
|
||||
use nalgebra::{ClosedAdd, DMatrix, Dim, Matrix, Scalar};
|
||||
use num_traits::Zero;
|
||||
|
||||
impl<'a, T, R, C, S> From<&'a Matrix<T, R, C, S>> for CooMatrix<T>
|
||||
where
|
||||
T: Scalar + Zero,
|
||||
R: Dim,
|
||||
C: Dim,
|
||||
S: Storage<T, R, C>
|
||||
S: Storage<T, R, C>,
|
||||
{
|
||||
fn from(matrix: &'a Matrix<T, R, C, S>) -> Self {
|
||||
convert_dense_coo(matrix)
|
||||
|
@ -29,7 +29,7 @@ where
|
|||
|
||||
impl<'a, T> From<&'a CooMatrix<T>> for CsrMatrix<T>
|
||||
where
|
||||
T: Scalar + Zero + ClosedAdd
|
||||
T: Scalar + Zero + ClosedAdd,
|
||||
{
|
||||
fn from(matrix: &'a CooMatrix<T>) -> Self {
|
||||
convert_coo_csr(matrix)
|
||||
|
@ -38,7 +38,7 @@ where
|
|||
|
||||
impl<'a, T> From<&'a CsrMatrix<T>> for CooMatrix<T>
|
||||
where
|
||||
T: Scalar + Zero + ClosedAdd
|
||||
T: Scalar + Zero + ClosedAdd,
|
||||
{
|
||||
fn from(matrix: &'a CsrMatrix<T>) -> Self {
|
||||
convert_csr_coo(matrix)
|
||||
|
@ -50,7 +50,7 @@ where
|
|||
T: Scalar + Zero,
|
||||
R: Dim,
|
||||
C: Dim,
|
||||
S: Storage<T, R, C>
|
||||
S: Storage<T, R, C>,
|
||||
{
|
||||
fn from(matrix: &'a Matrix<T, R, C, S>) -> Self {
|
||||
convert_dense_csr(matrix)
|
||||
|
@ -59,7 +59,7 @@ where
|
|||
|
||||
impl<'a, T> From<&'a CsrMatrix<T>> for DMatrix<T>
|
||||
where
|
||||
T: Scalar + Zero + ClosedAdd
|
||||
T: Scalar + Zero + ClosedAdd,
|
||||
{
|
||||
fn from(matrix: &'a CsrMatrix<T>) -> Self {
|
||||
convert_csr_dense(matrix)
|
||||
|
@ -68,7 +68,7 @@ where
|
|||
|
||||
impl<'a, T> From<&'a CooMatrix<T>> for CscMatrix<T>
|
||||
where
|
||||
T: Scalar + Zero + ClosedAdd
|
||||
T: Scalar + Zero + ClosedAdd,
|
||||
{
|
||||
fn from(matrix: &'a CooMatrix<T>) -> Self {
|
||||
convert_coo_csc(matrix)
|
||||
|
@ -77,7 +77,7 @@ where
|
|||
|
||||
impl<'a, T> From<&'a CscMatrix<T>> for CooMatrix<T>
|
||||
where
|
||||
T: Scalar + Zero
|
||||
T: Scalar + Zero,
|
||||
{
|
||||
fn from(matrix: &'a CscMatrix<T>) -> Self {
|
||||
convert_csc_coo(matrix)
|
||||
|
@ -85,11 +85,11 @@ where
|
|||
}
|
||||
|
||||
impl<'a, T, R, C, S> From<&'a Matrix<T, R, C, S>> for CscMatrix<T>
|
||||
where
|
||||
T: Scalar + Zero,
|
||||
R: Dim,
|
||||
C: Dim,
|
||||
S: Storage<T, R, C>
|
||||
where
|
||||
T: Scalar + Zero,
|
||||
R: Dim,
|
||||
C: Dim,
|
||||
S: Storage<T, R, C>,
|
||||
{
|
||||
fn from(matrix: &'a Matrix<T, R, C, S>) -> Self {
|
||||
convert_dense_csc(matrix)
|
||||
|
@ -97,8 +97,8 @@ impl<'a, T, R, C, S> From<&'a Matrix<T, R, C, S>> for CscMatrix<T>
|
|||
}
|
||||
|
||||
impl<'a, T> From<&'a CscMatrix<T>> for DMatrix<T>
|
||||
where
|
||||
T: Scalar + Zero + ClosedAdd
|
||||
where
|
||||
T: Scalar + Zero + ClosedAdd,
|
||||
{
|
||||
fn from(matrix: &'a CscMatrix<T>) -> Self {
|
||||
convert_csc_dense(matrix)
|
||||
|
@ -106,8 +106,8 @@ impl<'a, T> From<&'a CscMatrix<T>> for DMatrix<T>
|
|||
}
|
||||
|
||||
impl<'a, T> From<&'a CscMatrix<T>> for CsrMatrix<T>
|
||||
where
|
||||
T: Scalar
|
||||
where
|
||||
T: Scalar,
|
||||
{
|
||||
fn from(matrix: &'a CscMatrix<T>) -> Self {
|
||||
convert_csc_csr(matrix)
|
||||
|
@ -116,7 +116,7 @@ impl<'a, T> From<&'a CscMatrix<T>> for CsrMatrix<T>
|
|||
|
||||
impl<'a, T> From<&'a CsrMatrix<T>> for CscMatrix<T>
|
||||
where
|
||||
T: Scalar
|
||||
T: Scalar,
|
||||
{
|
||||
fn from(matrix: &'a CsrMatrix<T>) -> Self {
|
||||
convert_csr_csc(matrix)
|
||||
|
|
|
@ -7,8 +7,8 @@ use std::ops::Add;
|
|||
|
||||
use num_traits::Zero;
|
||||
|
||||
use nalgebra::{ClosedAdd, Dim, DMatrix, Matrix, Scalar};
|
||||
use nalgebra::storage::Storage;
|
||||
use nalgebra::{ClosedAdd, DMatrix, Dim, Matrix, Scalar};
|
||||
|
||||
use crate::coo::CooMatrix;
|
||||
use crate::cs;
|
||||
|
@ -21,7 +21,7 @@ where
|
|||
T: Scalar + Zero,
|
||||
R: Dim,
|
||||
C: Dim,
|
||||
S: Storage<T, R, C>
|
||||
S: Storage<T, R, C>,
|
||||
{
|
||||
let mut coo = CooMatrix::new(dense.nrows(), dense.ncols());
|
||||
|
||||
|
@ -52,12 +52,14 @@ where
|
|||
/// Converts a [`CooMatrix`] to a [`CsrMatrix`].
|
||||
pub fn convert_coo_csr<T>(coo: &CooMatrix<T>) -> CsrMatrix<T>
|
||||
where
|
||||
T: Scalar + Zero
|
||||
T: Scalar + Zero,
|
||||
{
|
||||
let (offsets, indices, values) = convert_coo_cs(coo.nrows(),
|
||||
coo.row_indices(),
|
||||
coo.col_indices(),
|
||||
coo.values());
|
||||
let (offsets, indices, values) = convert_coo_cs(
|
||||
coo.nrows(),
|
||||
coo.row_indices(),
|
||||
coo.col_indices(),
|
||||
coo.values(),
|
||||
);
|
||||
|
||||
// TODO: Avoid "try_from" since it validates the data? (requires unsafe, should benchmark
|
||||
// to see if it can be justified for performance reasons)
|
||||
|
@ -66,8 +68,7 @@ where
|
|||
}
|
||||
|
||||
/// Converts a [`CsrMatrix`] to a [`CooMatrix`].
|
||||
pub fn convert_csr_coo<T: Scalar>(csr: &CsrMatrix<T>) -> CooMatrix<T>
|
||||
{
|
||||
pub fn convert_csr_coo<T: Scalar>(csr: &CsrMatrix<T>) -> CooMatrix<T> {
|
||||
let mut result = CooMatrix::new(csr.nrows(), csr.ncols());
|
||||
for (i, j, v) in csr.triplet_iter() {
|
||||
result.push(i, j, v.inlined_clone());
|
||||
|
@ -76,9 +77,9 @@ pub fn convert_csr_coo<T: Scalar>(csr: &CsrMatrix<T>) -> CooMatrix<T>
|
|||
}
|
||||
|
||||
/// Converts a [`CsrMatrix`] to a dense matrix.
|
||||
pub fn convert_csr_dense<T>(csr:& CsrMatrix<T>) -> DMatrix<T>
|
||||
pub fn convert_csr_dense<T>(csr: &CsrMatrix<T>) -> DMatrix<T>
|
||||
where
|
||||
T: Scalar + ClosedAdd + Zero
|
||||
T: Scalar + ClosedAdd + Zero,
|
||||
{
|
||||
let mut output = DMatrix::zeros(csr.nrows(), csr.ncols());
|
||||
|
||||
|
@ -95,7 +96,7 @@ where
|
|||
T: Scalar + Zero,
|
||||
R: Dim,
|
||||
C: Dim,
|
||||
S: Storage<T, R, C>
|
||||
S: Storage<T, R, C>,
|
||||
{
|
||||
let mut row_offsets = Vec::with_capacity(dense.nrows() + 1);
|
||||
let mut col_idx = Vec::new();
|
||||
|
@ -105,8 +106,8 @@ where
|
|||
// nalgebra's column-major storage. The alternative would be to perform an initial sweep
|
||||
// to count number of non-zeros per row.
|
||||
row_offsets.push(0);
|
||||
for i in 0 .. dense.nrows() {
|
||||
for j in 0 .. dense.ncols() {
|
||||
for i in 0..dense.nrows() {
|
||||
for j in 0..dense.ncols() {
|
||||
let v = dense.index((i, j));
|
||||
if v != &T::zero() {
|
||||
col_idx.push(j);
|
||||
|
@ -125,12 +126,14 @@ where
|
|||
/// Converts a [`CooMatrix`] to a [`CscMatrix`].
|
||||
pub fn convert_coo_csc<T>(coo: &CooMatrix<T>) -> CscMatrix<T>
|
||||
where
|
||||
T: Scalar + Zero
|
||||
T: Scalar + Zero,
|
||||
{
|
||||
let (offsets, indices, values) = convert_coo_cs(coo.ncols(),
|
||||
coo.col_indices(),
|
||||
coo.row_indices(),
|
||||
coo.values());
|
||||
let (offsets, indices, values) = convert_coo_cs(
|
||||
coo.ncols(),
|
||||
coo.col_indices(),
|
||||
coo.row_indices(),
|
||||
coo.values(),
|
||||
);
|
||||
|
||||
// TODO: Avoid "try_from" since it validates the data? (requires unsafe, should benchmark
|
||||
// to see if it can be justified for performance reasons)
|
||||
|
@ -141,7 +144,7 @@ where
|
|||
/// Converts a [`CscMatrix`] to a [`CooMatrix`].
|
||||
pub fn convert_csc_coo<T>(csc: &CscMatrix<T>) -> CooMatrix<T>
|
||||
where
|
||||
T: Scalar
|
||||
T: Scalar,
|
||||
{
|
||||
let mut coo = CooMatrix::new(csc.nrows(), csc.ncols());
|
||||
for (i, j, v) in csc.triplet_iter() {
|
||||
|
@ -153,7 +156,7 @@ where
|
|||
/// Converts a [`CscMatrix`] to a dense matrix.
|
||||
pub fn convert_csc_dense<T>(csc: &CscMatrix<T>) -> DMatrix<T>
|
||||
where
|
||||
T: Scalar + ClosedAdd + Zero
|
||||
T: Scalar + ClosedAdd + Zero,
|
||||
{
|
||||
let mut output = DMatrix::zeros(csc.nrows(), csc.ncols());
|
||||
|
||||
|
@ -166,19 +169,19 @@ where
|
|||
|
||||
/// Converts a dense matrix to a [`CscMatrix`].
|
||||
pub fn convert_dense_csc<T, R, C, S>(dense: &Matrix<T, R, C, S>) -> CscMatrix<T>
|
||||
where
|
||||
T: Scalar + Zero,
|
||||
R: Dim,
|
||||
C: Dim,
|
||||
S: Storage<T, R, C>
|
||||
where
|
||||
T: Scalar + Zero,
|
||||
R: Dim,
|
||||
C: Dim,
|
||||
S: Storage<T, R, C>,
|
||||
{
|
||||
let mut col_offsets = Vec::with_capacity(dense.ncols() + 1);
|
||||
let mut row_idx = Vec::new();
|
||||
let mut values = Vec::new();
|
||||
|
||||
col_offsets.push(0);
|
||||
for j in 0 .. dense.ncols() {
|
||||
for i in 0 .. dense.nrows() {
|
||||
for j in 0..dense.ncols() {
|
||||
for i in 0..dense.nrows() {
|
||||
let v = dense.index((i, j));
|
||||
if v != &T::zero() {
|
||||
row_idx.push(i);
|
||||
|
@ -197,13 +200,15 @@ pub fn convert_dense_csc<T, R, C, S>(dense: &Matrix<T, R, C, S>) -> CscMatrix<T>
|
|||
/// Converts a [`CsrMatrix`] to a [`CscMatrix`].
|
||||
pub fn convert_csr_csc<T>(csr: &CsrMatrix<T>) -> CscMatrix<T>
|
||||
where
|
||||
T: Scalar
|
||||
T: Scalar,
|
||||
{
|
||||
let (offsets, indices, values) = cs::transpose_cs(csr.nrows(),
|
||||
csr.ncols(),
|
||||
csr.row_offsets(),
|
||||
csr.col_indices(),
|
||||
csr.values());
|
||||
let (offsets, indices, values) = cs::transpose_cs(
|
||||
csr.nrows(),
|
||||
csr.ncols(),
|
||||
csr.row_offsets(),
|
||||
csr.col_indices(),
|
||||
csr.values(),
|
||||
);
|
||||
|
||||
// TODO: Avoid data validity check?
|
||||
CscMatrix::try_from_csc_data(csr.nrows(), csr.ncols(), offsets, indices, values)
|
||||
|
@ -212,27 +217,30 @@ where
|
|||
|
||||
/// Converts a [`CscMatrix`] to a [`CsrMatrix`].
|
||||
pub fn convert_csc_csr<T>(csc: &CscMatrix<T>) -> CsrMatrix<T>
|
||||
where
|
||||
T: Scalar
|
||||
where
|
||||
T: Scalar,
|
||||
{
|
||||
let (offsets, indices, values) = cs::transpose_cs(csc.ncols(),
|
||||
csc.nrows(),
|
||||
csc.col_offsets(),
|
||||
csc.row_indices(),
|
||||
csc.values());
|
||||
let (offsets, indices, values) = cs::transpose_cs(
|
||||
csc.ncols(),
|
||||
csc.nrows(),
|
||||
csc.col_offsets(),
|
||||
csc.row_indices(),
|
||||
csc.values(),
|
||||
);
|
||||
|
||||
// TODO: Avoid data validity check?
|
||||
CsrMatrix::try_from_csr_data(csc.nrows(), csc.ncols(), offsets, indices, values)
|
||||
.expect("Internal error: Invalid CSR data during CSC->CSR conversion")
|
||||
}
|
||||
|
||||
fn convert_coo_cs<T>(major_dim: usize,
|
||||
major_indices: &[usize],
|
||||
minor_indices: &[usize],
|
||||
values: &[T])
|
||||
-> (Vec<usize>, Vec<usize>, Vec<T>)
|
||||
fn convert_coo_cs<T>(
|
||||
major_dim: usize,
|
||||
major_indices: &[usize],
|
||||
minor_indices: &[usize],
|
||||
values: &[T],
|
||||
) -> (Vec<usize>, Vec<usize>, Vec<T>)
|
||||
where
|
||||
T: Scalar + Zero
|
||||
T: Scalar + Zero,
|
||||
{
|
||||
assert_eq!(major_indices.len(), minor_indices.len());
|
||||
assert_eq!(minor_indices.len(), values.len());
|
||||
|
|
|
@ -45,8 +45,7 @@ pub struct CooMatrix<T> {
|
|||
values: Vec<T>,
|
||||
}
|
||||
|
||||
impl<T> CooMatrix<T>
|
||||
{
|
||||
impl<T> CooMatrix<T> {
|
||||
/// Construct a zero COO matrix of the given dimensions.
|
||||
///
|
||||
/// Specifically, the collection of triplets - corresponding to explicitly stored entries -
|
||||
|
@ -78,11 +77,13 @@ impl<T> CooMatrix<T>
|
|||
use crate::SparseFormatErrorKind::*;
|
||||
if row_indices.len() != col_indices.len() {
|
||||
return Err(SparseFormatError::from_kind_and_msg(
|
||||
InvalidStructure, "Number of row and col indices must be the same."
|
||||
InvalidStructure,
|
||||
"Number of row and col indices must be the same.",
|
||||
));
|
||||
} else if col_indices.len() != values.len() {
|
||||
return Err(SparseFormatError::from_kind_and_msg(
|
||||
InvalidStructure, "Number of col indices and values must be the same."
|
||||
InvalidStructure,
|
||||
"Number of col indices and values must be the same.",
|
||||
));
|
||||
}
|
||||
|
||||
|
@ -90,9 +91,15 @@ impl<T> CooMatrix<T>
|
|||
let col_indices_in_bounds = col_indices.iter().all(|j| *j < ncols);
|
||||
|
||||
if !row_indices_in_bounds {
|
||||
Err(SparseFormatError::from_kind_and_msg(IndexOutOfBounds, "Row index out of bounds."))
|
||||
Err(SparseFormatError::from_kind_and_msg(
|
||||
IndexOutOfBounds,
|
||||
"Row index out of bounds.",
|
||||
))
|
||||
} else if !col_indices_in_bounds {
|
||||
Err(SparseFormatError::from_kind_and_msg(IndexOutOfBounds, "Col index out of bounds."))
|
||||
Err(SparseFormatError::from_kind_and_msg(
|
||||
IndexOutOfBounds,
|
||||
"Col index out of bounds.",
|
||||
))
|
||||
} else {
|
||||
Ok(Self {
|
||||
nrows,
|
||||
|
|
|
@ -5,8 +5,8 @@ use num_traits::One;
|
|||
|
||||
use nalgebra::Scalar;
|
||||
|
||||
use crate::{SparseEntry, SparseEntryMut};
|
||||
use crate::pattern::SparsityPattern;
|
||||
use crate::{SparseEntry, SparseEntryMut};
|
||||
|
||||
/// An abstract compressed matrix.
|
||||
///
|
||||
|
@ -18,7 +18,7 @@ use crate::pattern::SparsityPattern;
|
|||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct CsMatrix<T> {
|
||||
sparsity_pattern: SparsityPattern,
|
||||
values: Vec<T>
|
||||
values: Vec<T>,
|
||||
}
|
||||
|
||||
impl<T> CsMatrix<T> {
|
||||
|
@ -50,14 +50,22 @@ impl<T> CsMatrix<T> {
|
|||
#[inline]
|
||||
pub fn cs_data(&self) -> (&[usize], &[usize], &[T]) {
|
||||
let pattern = self.pattern();
|
||||
(pattern.major_offsets(), pattern.minor_indices(), &self.values)
|
||||
(
|
||||
pattern.major_offsets(),
|
||||
pattern.minor_indices(),
|
||||
&self.values,
|
||||
)
|
||||
}
|
||||
|
||||
/// Returns the raw data represented as a tuple `(major_offsets, minor_indices, values)`.
|
||||
#[inline]
|
||||
pub fn cs_data_mut(&mut self) -> (&[usize], &[usize], &mut [T]) {
|
||||
let pattern = &mut self.sparsity_pattern;
|
||||
(pattern.major_offsets(), pattern.minor_indices(), &mut self.values)
|
||||
(
|
||||
pattern.major_offsets(),
|
||||
pattern.minor_indices(),
|
||||
&mut self.values,
|
||||
)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
|
@ -66,9 +74,12 @@ impl<T> CsMatrix<T> {
|
|||
}
|
||||
|
||||
#[inline]
|
||||
pub fn from_pattern_and_values(pattern: SparsityPattern, values: Vec<T>)
|
||||
-> Self {
|
||||
assert_eq!(pattern.nnz(), values.len(), "Internal error: consumers should verify shape compatibility.");
|
||||
pub fn from_pattern_and_values(pattern: SparsityPattern, values: Vec<T>) -> Self {
|
||||
assert_eq!(
|
||||
pattern.nnz(),
|
||||
values.len(),
|
||||
"Internal error: consumers should verify shape compatibility."
|
||||
);
|
||||
Self {
|
||||
sparsity_pattern: pattern,
|
||||
values,
|
||||
|
@ -80,7 +91,7 @@ impl<T> CsMatrix<T> {
|
|||
pub fn get_index_range(&self, row_index: usize) -> Option<Range<usize>> {
|
||||
let row_begin = *self.sparsity_pattern.major_offsets().get(row_index)?;
|
||||
let row_end = *self.sparsity_pattern.major_offsets().get(row_index + 1)?;
|
||||
Some(row_begin .. row_end)
|
||||
Some(row_begin..row_end)
|
||||
}
|
||||
|
||||
pub fn take_pattern_and_values(self) -> (SparsityPattern, Vec<T>) {
|
||||
|
@ -105,13 +116,21 @@ impl<T> CsMatrix<T> {
|
|||
let (_, minor_indices, values) = self.cs_data();
|
||||
let minor_indices = &minor_indices[row_range.clone()];
|
||||
let values = &values[row_range];
|
||||
get_entry_from_slices(self.pattern().minor_dim(), minor_indices, values, minor_index)
|
||||
get_entry_from_slices(
|
||||
self.pattern().minor_dim(),
|
||||
minor_indices,
|
||||
values,
|
||||
minor_index,
|
||||
)
|
||||
}
|
||||
|
||||
/// Returns a mutable entry for the given major/minor indices, or `None` if the indices are out
|
||||
/// of bounds.
|
||||
pub fn get_entry_mut(&mut self, major_index: usize, minor_index: usize)
|
||||
-> Option<SparseEntryMut<T>> {
|
||||
pub fn get_entry_mut(
|
||||
&mut self,
|
||||
major_index: usize,
|
||||
minor_index: usize,
|
||||
) -> Option<SparseEntryMut<T>> {
|
||||
let row_range = self.get_index_range(major_index)?;
|
||||
let minor_dim = self.pattern().minor_dim();
|
||||
let (_, minor_indices, values) = self.cs_data_mut();
|
||||
|
@ -126,7 +145,7 @@ impl<T> CsMatrix<T> {
|
|||
Some(CsLane {
|
||||
minor_indices: &minor_indices[range.clone()],
|
||||
values: &values[range],
|
||||
minor_dim: self.pattern().minor_dim()
|
||||
minor_dim: self.pattern().minor_dim(),
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -138,7 +157,7 @@ impl<T> CsMatrix<T> {
|
|||
Some(CsLaneMut {
|
||||
minor_dim,
|
||||
minor_indices: &minor_indices[range.clone()],
|
||||
values: &mut values[range]
|
||||
values: &mut values[range],
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -156,7 +175,7 @@ impl<T> CsMatrix<T> {
|
|||
pub fn filter<P>(&self, predicate: P) -> Self
|
||||
where
|
||||
T: Clone,
|
||||
P: Fn(usize, usize, &T) -> bool
|
||||
P: Fn(usize, usize, &T) -> bool,
|
||||
{
|
||||
let (major_dim, minor_dim) = (self.pattern().major_dim(), self.pattern().minor_dim());
|
||||
let mut new_offsets = Vec::with_capacity(self.pattern().major_dim() + 1);
|
||||
|
@ -180,16 +199,17 @@ impl<T> CsMatrix<T> {
|
|||
major_dim,
|
||||
minor_dim,
|
||||
new_offsets,
|
||||
new_indices)
|
||||
.expect("Internal error: Sparsity pattern must always be valid.");
|
||||
new_indices,
|
||||
)
|
||||
.expect("Internal error: Sparsity pattern must always be valid.");
|
||||
|
||||
Self::from_pattern_and_values(new_pattern, new_values)
|
||||
}
|
||||
|
||||
/// Returns the diagonal of the matrix as a sparse matrix.
|
||||
pub fn diagonal_as_matrix(&self) -> Self
|
||||
where
|
||||
T: Clone
|
||||
where
|
||||
T: Clone,
|
||||
{
|
||||
// TODO: This might be faster with a binary search for each diagonal entry
|
||||
self.filter(|i, j, _| i == j)
|
||||
|
@ -199,13 +219,13 @@ impl<T> CsMatrix<T> {
|
|||
impl<T: Scalar + One> CsMatrix<T> {
|
||||
#[inline]
|
||||
pub fn identity(n: usize) -> Self {
|
||||
let offsets: Vec<_> = (0 ..= n).collect();
|
||||
let indices: Vec<_> = (0 .. n).collect();
|
||||
let offsets: Vec<_> = (0..=n).collect();
|
||||
let indices: Vec<_> = (0..n).collect();
|
||||
let values = vec![T::one(); n];
|
||||
|
||||
// TODO: We should skip checks here
|
||||
let pattern = SparsityPattern::try_from_offsets_and_indices(n, n, offsets, indices)
|
||||
.unwrap();
|
||||
let pattern =
|
||||
SparsityPattern::try_from_offsets_and_indices(n, n, offsets, indices).unwrap();
|
||||
Self::from_pattern_and_values(pattern, values)
|
||||
}
|
||||
}
|
||||
|
@ -214,7 +234,8 @@ fn get_entry_from_slices<'a, T>(
|
|||
minor_dim: usize,
|
||||
minor_indices: &'a [usize],
|
||||
values: &'a [T],
|
||||
global_minor_index: usize) -> Option<SparseEntry<'a, T>> {
|
||||
global_minor_index: usize,
|
||||
) -> Option<SparseEntry<'a, T>> {
|
||||
let local_index = minor_indices.binary_search(&global_minor_index);
|
||||
if let Ok(local_index) = local_index {
|
||||
Some(SparseEntry::NonZero(&values[local_index]))
|
||||
|
@ -229,7 +250,8 @@ fn get_mut_entry_from_slices<'a, T>(
|
|||
minor_dim: usize,
|
||||
minor_indices: &'a [usize],
|
||||
values: &'a mut [T],
|
||||
global_minor_indices: usize) -> Option<SparseEntryMut<'a, T>> {
|
||||
global_minor_indices: usize,
|
||||
) -> Option<SparseEntryMut<'a, T>> {
|
||||
let local_index = minor_indices.binary_search(&global_minor_indices);
|
||||
if let Ok(local_index) = local_index {
|
||||
Some(SparseEntryMut::NonZero(&mut values[local_index]))
|
||||
|
@ -244,14 +266,14 @@ fn get_mut_entry_from_slices<'a, T>(
|
|||
pub struct CsLane<'a, T> {
|
||||
minor_dim: usize,
|
||||
minor_indices: &'a [usize],
|
||||
values: &'a [T]
|
||||
values: &'a [T],
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
pub struct CsLaneMut<'a, T> {
|
||||
minor_dim: usize,
|
||||
minor_indices: &'a [usize],
|
||||
values: &'a mut [T]
|
||||
values: &'a mut [T],
|
||||
}
|
||||
|
||||
pub struct CsLaneIter<'a, T> {
|
||||
|
@ -266,14 +288,14 @@ impl<'a, T> CsLaneIter<'a, T> {
|
|||
Self {
|
||||
current_lane_idx: 0,
|
||||
pattern,
|
||||
remaining_values: values
|
||||
remaining_values: values,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T> Iterator for CsLaneIter<'a, T>
|
||||
where
|
||||
T: 'a
|
||||
where
|
||||
T: 'a,
|
||||
{
|
||||
type Item = CsLane<'a, T>;
|
||||
|
||||
|
@ -284,13 +306,13 @@ impl<'a, T> Iterator for CsLaneIter<'a, T>
|
|||
if let Some(minor_indices) = lane {
|
||||
let count = minor_indices.len();
|
||||
let values_in_lane = &self.remaining_values[..count];
|
||||
self.remaining_values = &self.remaining_values[count ..];
|
||||
self.remaining_values = &self.remaining_values[count..];
|
||||
self.current_lane_idx += 1;
|
||||
|
||||
Some(CsLane {
|
||||
minor_dim,
|
||||
minor_indices,
|
||||
values: values_in_lane
|
||||
values: values_in_lane,
|
||||
})
|
||||
} else {
|
||||
None
|
||||
|
@ -310,14 +332,14 @@ impl<'a, T> CsLaneIterMut<'a, T> {
|
|||
Self {
|
||||
current_lane_idx: 0,
|
||||
pattern,
|
||||
remaining_values: values
|
||||
remaining_values: values,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T> Iterator for CsLaneIterMut<'a, T>
|
||||
where
|
||||
T: 'a
|
||||
where
|
||||
T: 'a,
|
||||
{
|
||||
type Item = CsLaneMut<'a, T>;
|
||||
|
||||
|
@ -336,7 +358,7 @@ impl<'a, T> Iterator for CsLaneIterMut<'a, T>
|
|||
Some(CsLaneMut {
|
||||
minor_dim,
|
||||
minor_indices,
|
||||
values: values_in_lane
|
||||
values: values_in_lane,
|
||||
})
|
||||
} else {
|
||||
None
|
||||
|
@ -375,10 +397,11 @@ macro_rules! impl_cs_lane_common_methods {
|
|||
self.minor_dim,
|
||||
self.minor_indices,
|
||||
self.values,
|
||||
global_col_index)
|
||||
global_col_index,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
impl_cs_lane_common_methods!(CsLane<'a, T>);
|
||||
|
@ -394,10 +417,12 @@ impl<'a, T> CsLaneMut<'a, T> {
|
|||
}
|
||||
|
||||
pub fn get_entry_mut(&mut self, global_minor_index: usize) -> Option<SparseEntryMut<T>> {
|
||||
get_mut_entry_from_slices(self.minor_dim,
|
||||
self.minor_indices,
|
||||
self.values,
|
||||
global_minor_index)
|
||||
get_mut_entry_from_slices(
|
||||
self.minor_dim,
|
||||
self.minor_indices,
|
||||
self.values,
|
||||
global_minor_index,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -405,7 +430,7 @@ impl<'a, T> CsLaneMut<'a, T> {
|
|||
/// TODO: This doesn't belong here.
|
||||
struct UninitVec<T> {
|
||||
vec: Vec<T>,
|
||||
len: usize
|
||||
len: usize,
|
||||
}
|
||||
|
||||
impl<T> UninitVec<T> {
|
||||
|
@ -414,7 +439,7 @@ impl<T> UninitVec<T> {
|
|||
vec: Vec::with_capacity(len),
|
||||
// We need to store len separately, because for zero-sized types,
|
||||
// Vec::with_capacity(len) does not give vec.capacity() == len
|
||||
len
|
||||
len,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -440,14 +465,14 @@ impl<T> UninitVec<T> {
|
|||
/// This means that major and minor roles are switched. This is used for converting between CSR
|
||||
/// and CSC formats.
|
||||
pub fn transpose_cs<T>(
|
||||
major_dim: usize,
|
||||
minor_dim: usize,
|
||||
source_major_offsets: &[usize],
|
||||
source_minor_indices: &[usize],
|
||||
values: &[T])
|
||||
-> (Vec<usize>, Vec<usize>, Vec<T>)
|
||||
major_dim: usize,
|
||||
minor_dim: usize,
|
||||
source_major_offsets: &[usize],
|
||||
source_minor_indices: &[usize],
|
||||
values: &[T],
|
||||
) -> (Vec<usize>, Vec<usize>, Vec<T>)
|
||||
where
|
||||
T: Scalar
|
||||
T: Scalar,
|
||||
{
|
||||
assert_eq!(source_major_offsets.len(), major_dim + 1);
|
||||
assert_eq!(source_minor_indices.len(), values.len());
|
||||
|
@ -470,18 +495,20 @@ where
|
|||
// Keep track of how many entries we have placed in each target major lane
|
||||
let mut current_target_major_counts = vec![0; minor_dim];
|
||||
|
||||
for source_major_idx in 0 .. major_dim {
|
||||
for source_major_idx in 0..major_dim {
|
||||
let source_lane_begin = source_major_offsets[source_major_idx];
|
||||
let source_lane_end = source_major_offsets[source_major_idx + 1];
|
||||
let source_lane_indices = &source_minor_indices[source_lane_begin .. source_lane_end];
|
||||
let source_lane_values = &values[source_lane_begin .. source_lane_end];
|
||||
let source_lane_indices = &source_minor_indices[source_lane_begin..source_lane_end];
|
||||
let source_lane_values = &values[source_lane_begin..source_lane_end];
|
||||
|
||||
for (&source_minor_idx, val) in source_lane_indices.iter().zip(source_lane_values) {
|
||||
// Compute the offset in the target data for this particular source entry
|
||||
let target_lane_count = &mut current_target_major_counts[source_minor_idx];
|
||||
let target_lane_count = &mut current_target_major_counts[source_minor_idx];
|
||||
let entry_offset = target_offsets[source_minor_idx] + *target_lane_count;
|
||||
target_indices[entry_offset] = source_major_idx;
|
||||
unsafe { target_values.set(entry_offset, val.inlined_clone()); }
|
||||
unsafe {
|
||||
target_values.set(entry_offset, val.inlined_clone());
|
||||
}
|
||||
*target_lane_count += 1;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,14 +3,14 @@
|
|||
//! This is the module-level documentation. See [`CscMatrix`] for the main documentation of the
|
||||
//! CSC implementation.
|
||||
|
||||
use crate::{SparseFormatError, SparseFormatErrorKind, SparseEntry, SparseEntryMut};
|
||||
use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter};
|
||||
use crate::cs::{CsLane, CsLaneIter, CsLaneIterMut, CsLaneMut, CsMatrix};
|
||||
use crate::csr::CsrMatrix;
|
||||
use crate::cs::{CsMatrix, CsLane, CsLaneMut, CsLaneIter, CsLaneIterMut};
|
||||
use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter};
|
||||
use crate::{SparseEntry, SparseEntryMut, SparseFormatError, SparseFormatErrorKind};
|
||||
|
||||
use std::slice::{IterMut, Iter};
|
||||
use num_traits::{One};
|
||||
use nalgebra::Scalar;
|
||||
use num_traits::One;
|
||||
use std::slice::{Iter, IterMut};
|
||||
|
||||
/// A CSC representation of a sparse matrix.
|
||||
///
|
||||
|
@ -130,7 +130,7 @@ impl<T> CscMatrix<T> {
|
|||
/// Create a zero CSC matrix with no explicitly stored entries.
|
||||
pub fn zeros(nrows: usize, ncols: usize) -> Self {
|
||||
Self {
|
||||
cs: CsMatrix::new(ncols, nrows)
|
||||
cs: CsMatrix::new(ncols, nrows),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -196,8 +196,12 @@ impl<T> CscMatrix<T> {
|
|||
values: Vec<T>,
|
||||
) -> Result<Self, SparseFormatError> {
|
||||
let pattern = SparsityPattern::try_from_offsets_and_indices(
|
||||
num_cols, num_rows, col_offsets, row_indices)
|
||||
.map_err(pattern_format_error_to_csc_error)?;
|
||||
num_cols,
|
||||
num_rows,
|
||||
col_offsets,
|
||||
row_indices,
|
||||
)
|
||||
.map_err(pattern_format_error_to_csc_error)?;
|
||||
Self::try_from_pattern_and_values(pattern, values)
|
||||
}
|
||||
|
||||
|
@ -205,16 +209,19 @@ impl<T> CscMatrix<T> {
|
|||
///
|
||||
/// Returns an error if the number of values does not match the number of minor indices
|
||||
/// in the pattern.
|
||||
pub fn try_from_pattern_and_values(pattern: SparsityPattern, values: Vec<T>)
|
||||
-> Result<Self, SparseFormatError> {
|
||||
pub fn try_from_pattern_and_values(
|
||||
pattern: SparsityPattern,
|
||||
values: Vec<T>,
|
||||
) -> Result<Self, SparseFormatError> {
|
||||
if pattern.nnz() == values.len() {
|
||||
Ok(Self {
|
||||
cs: CsMatrix::from_pattern_and_values(pattern, values)
|
||||
cs: CsMatrix::from_pattern_and_values(pattern, values),
|
||||
})
|
||||
} else {
|
||||
Err(SparseFormatError::from_kind_and_msg(
|
||||
SparseFormatErrorKind::InvalidStructure,
|
||||
"Number of values and row indices must be the same"))
|
||||
"Number of values and row indices must be the same",
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -239,7 +246,7 @@ impl<T> CscMatrix<T> {
|
|||
pub fn triplet_iter(&self) -> CscTripletIter<T> {
|
||||
CscTripletIter {
|
||||
pattern_iter: self.pattern().entries(),
|
||||
values_iter: self.values().iter()
|
||||
values_iter: self.values().iter(),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -270,7 +277,7 @@ impl<T> CscMatrix<T> {
|
|||
let (pattern, values) = self.cs.pattern_and_values_mut();
|
||||
CscTripletIterMut {
|
||||
pattern_iter: pattern.entries(),
|
||||
values_mut_iter: values.iter_mut()
|
||||
values_mut_iter: values.iter_mut(),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -281,8 +288,7 @@ impl<T> CscMatrix<T> {
|
|||
/// Panics if column index is out of bounds.
|
||||
#[inline]
|
||||
pub fn col(&self, index: usize) -> CscCol<T> {
|
||||
self.get_col(index)
|
||||
.expect("Row index must be in bounds")
|
||||
self.get_col(index).expect("Row index must be in bounds")
|
||||
}
|
||||
|
||||
/// Mutable column access for the given column index.
|
||||
|
@ -299,23 +305,19 @@ impl<T> CscMatrix<T> {
|
|||
/// Return the column at the given column index, or `None` if out of bounds.
|
||||
#[inline]
|
||||
pub fn get_col(&self, index: usize) -> Option<CscCol<T>> {
|
||||
self.cs
|
||||
.get_lane(index)
|
||||
.map(|lane| CscCol { lane })
|
||||
self.cs.get_lane(index).map(|lane| CscCol { lane })
|
||||
}
|
||||
|
||||
/// Mutable column access for the given column index, or `None` if out of bounds.
|
||||
#[inline]
|
||||
pub fn get_col_mut(&mut self, index: usize) -> Option<CscColMut<T>> {
|
||||
self.cs
|
||||
.get_lane_mut(index)
|
||||
.map(|lane| CscColMut { lane })
|
||||
self.cs.get_lane_mut(index).map(|lane| CscColMut { lane })
|
||||
}
|
||||
|
||||
/// An iterator over columns in the matrix.
|
||||
pub fn col_iter(&self) -> CscColIter<T> {
|
||||
CscColIter {
|
||||
lane_iter: CsLaneIter::new(self.pattern(), self.values())
|
||||
lane_iter: CsLaneIter::new(self.pattern(), self.values()),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -323,7 +325,7 @@ impl<T> CscMatrix<T> {
|
|||
pub fn col_iter_mut(&mut self) -> CscColIterMut<T> {
|
||||
let (pattern, values) = self.cs.pattern_and_values_mut();
|
||||
CscColIterMut {
|
||||
lane_iter: CsLaneIterMut::new(pattern, values)
|
||||
lane_iter: CsLaneIterMut::new(pattern, values),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -397,8 +399,11 @@ impl<T> CscMatrix<T> {
|
|||
///
|
||||
/// Each call to this function incurs the cost of a binary search among the explicitly
|
||||
/// stored row entries for the given column.
|
||||
pub fn get_entry_mut(&mut self, row_index: usize, col_index: usize)
|
||||
-> Option<SparseEntryMut<T>> {
|
||||
pub fn get_entry_mut(
|
||||
&mut self,
|
||||
row_index: usize,
|
||||
col_index: usize,
|
||||
) -> Option<SparseEntryMut<T>> {
|
||||
self.cs.get_entry_mut(col_index, row_index)
|
||||
}
|
||||
|
||||
|
@ -444,11 +449,15 @@ impl<T> CscMatrix<T> {
|
|||
pub fn filter<P>(&self, predicate: P) -> Self
|
||||
where
|
||||
T: Clone,
|
||||
P: Fn(usize, usize, &T) -> bool
|
||||
P: Fn(usize, usize, &T) -> bool,
|
||||
{
|
||||
// Note: Predicate uses (row, col, value), so we have to switch around since
|
||||
// cs uses (major, minor, value)
|
||||
Self { cs: self.cs.filter(|col_idx, row_idx, v| predicate(row_idx, col_idx, v)) }
|
||||
Self {
|
||||
cs: self
|
||||
.cs
|
||||
.filter(|col_idx, row_idx, v| predicate(row_idx, col_idx, v)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a new matrix representing the upper triangular part of this matrix.
|
||||
|
@ -456,7 +465,7 @@ impl<T> CscMatrix<T> {
|
|||
/// The result includes the diagonal of the matrix.
|
||||
pub fn upper_triangle(&self) -> Self
|
||||
where
|
||||
T: Clone
|
||||
T: Clone,
|
||||
{
|
||||
self.filter(|i, j, _| i <= j)
|
||||
}
|
||||
|
@ -466,7 +475,7 @@ impl<T> CscMatrix<T> {
|
|||
/// The result includes the diagonal of the matrix.
|
||||
pub fn lower_triangle(&self) -> Self
|
||||
where
|
||||
T: Clone
|
||||
T: Clone,
|
||||
{
|
||||
self.filter(|i, j, _| i >= j)
|
||||
}
|
||||
|
@ -474,15 +483,17 @@ impl<T> CscMatrix<T> {
|
|||
/// Returns the diagonal of the matrix as a sparse matrix.
|
||||
pub fn diagonal_as_csc(&self) -> Self
|
||||
where
|
||||
T: Clone
|
||||
T: Clone,
|
||||
{
|
||||
Self { cs: self.cs.diagonal_as_matrix() }
|
||||
Self {
|
||||
cs: self.cs.diagonal_as_matrix(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> CscMatrix<T>
|
||||
where
|
||||
T: Scalar
|
||||
where
|
||||
T: Scalar,
|
||||
{
|
||||
/// Compute the transpose of the matrix.
|
||||
pub fn transpose(&self) -> CscMatrix<T> {
|
||||
|
@ -495,7 +506,7 @@ impl<T: Scalar + One> CscMatrix<T> {
|
|||
#[inline]
|
||||
pub fn identity(n: usize) -> Self {
|
||||
Self {
|
||||
cs: CsMatrix::identity(n)
|
||||
cs: CsMatrix::identity(n),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -505,30 +516,34 @@ impl<T: Scalar + One> CscMatrix<T> {
|
|||
/// This ensures that the terminology is consistent: we are talking about rows and columns,
|
||||
/// not lanes, major and minor dimensions.
|
||||
fn pattern_format_error_to_csc_error(err: SparsityPatternFormatError) -> SparseFormatError {
|
||||
use SparsityPatternFormatError::*;
|
||||
use SparsityPatternFormatError::DuplicateEntry as PatternDuplicateEntry;
|
||||
use SparseFormatError as E;
|
||||
use SparseFormatErrorKind as K;
|
||||
use SparsityPatternFormatError::DuplicateEntry as PatternDuplicateEntry;
|
||||
use SparsityPatternFormatError::*;
|
||||
|
||||
match err {
|
||||
InvalidOffsetArrayLength => E::from_kind_and_msg(
|
||||
K::InvalidStructure,
|
||||
"Length of col offset array is not equal to ncols + 1."),
|
||||
"Length of col offset array is not equal to ncols + 1.",
|
||||
),
|
||||
InvalidOffsetFirstLast => E::from_kind_and_msg(
|
||||
K::InvalidStructure,
|
||||
"First or last col offset is inconsistent with format specification."),
|
||||
"First or last col offset is inconsistent with format specification.",
|
||||
),
|
||||
NonmonotonicOffsets => E::from_kind_and_msg(
|
||||
K::InvalidStructure,
|
||||
"Col offsets are not monotonically increasing."),
|
||||
"Col offsets are not monotonically increasing.",
|
||||
),
|
||||
NonmonotonicMinorIndices => E::from_kind_and_msg(
|
||||
K::InvalidStructure,
|
||||
"Row indices are not monotonically increasing (sorted) within each column."),
|
||||
MinorIndexOutOfBounds => E::from_kind_and_msg(
|
||||
K::IndexOutOfBounds,
|
||||
"Row indices are out of bounds."),
|
||||
PatternDuplicateEntry => E::from_kind_and_msg(
|
||||
K::DuplicateEntry,
|
||||
"Matrix data contains duplicate entries."),
|
||||
"Row indices are not monotonically increasing (sorted) within each column.",
|
||||
),
|
||||
MinorIndexOutOfBounds => {
|
||||
E::from_kind_and_msg(K::IndexOutOfBounds, "Row indices are out of bounds.")
|
||||
}
|
||||
PatternDuplicateEntry => {
|
||||
E::from_kind_and_msg(K::DuplicateEntry, "Matrix data contains duplicate entries.")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -536,7 +551,7 @@ fn pattern_format_error_to_csc_error(err: SparsityPatternFormatError) -> SparseF
|
|||
#[derive(Debug)]
|
||||
pub struct CscTripletIter<'a, T> {
|
||||
pattern_iter: SparsityPatternIter<'a>,
|
||||
values_iter: Iter<'a, T>
|
||||
values_iter: Iter<'a, T>,
|
||||
}
|
||||
|
||||
impl<'a, T: Clone> CscTripletIter<'a, T> {
|
||||
|
@ -545,7 +560,7 @@ impl<'a, T: Clone> CscTripletIter<'a, T> {
|
|||
/// The triplet iterator returns references to the values. This method adapts the iterator
|
||||
/// so that the values are cloned.
|
||||
#[inline]
|
||||
pub fn cloned_values(self) -> impl 'a + Iterator<Item=(usize, usize, T)> {
|
||||
pub fn cloned_values(self) -> impl 'a + Iterator<Item = (usize, usize, T)> {
|
||||
self.map(|(i, j, v)| (i, j, v.clone()))
|
||||
}
|
||||
}
|
||||
|
@ -559,7 +574,7 @@ impl<'a, T> Iterator for CscTripletIter<'a, T> {
|
|||
|
||||
match (next_entry, next_value) {
|
||||
(Some((i, j)), Some(v)) => Some((j, i, v)),
|
||||
_ => None
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -568,7 +583,7 @@ impl<'a, T> Iterator for CscTripletIter<'a, T> {
|
|||
#[derive(Debug)]
|
||||
pub struct CscTripletIterMut<'a, T> {
|
||||
pattern_iter: SparsityPatternIter<'a>,
|
||||
values_mut_iter: IterMut<'a, T>
|
||||
values_mut_iter: IterMut<'a, T>,
|
||||
}
|
||||
|
||||
impl<'a, T> Iterator for CscTripletIterMut<'a, T> {
|
||||
|
@ -581,7 +596,7 @@ impl<'a, T> Iterator for CscTripletIterMut<'a, T> {
|
|||
|
||||
match (next_entry, next_value) {
|
||||
(Some((i, j)), Some(v)) => Some((j, i, v)),
|
||||
_ => None
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -589,7 +604,7 @@ impl<'a, T> Iterator for CscTripletIterMut<'a, T> {
|
|||
/// An immutable representation of a column in a CSC matrix.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct CscCol<'a, T> {
|
||||
lane: CsLane<'a, T>
|
||||
lane: CsLane<'a, T>,
|
||||
}
|
||||
|
||||
/// A mutable representation of a column in a CSC matrix.
|
||||
|
@ -598,7 +613,7 @@ pub struct CscCol<'a, T> {
|
|||
/// to the column cannot be modified.
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
pub struct CscColMut<'a, T> {
|
||||
lane: CsLaneMut<'a, T>
|
||||
lane: CsLaneMut<'a, T>,
|
||||
}
|
||||
|
||||
/// Implement the methods common to both CscCol and CscColMut
|
||||
|
@ -637,7 +652,7 @@ macro_rules! impl_csc_col_common_methods {
|
|||
self.lane.get_entry(global_row_index)
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
impl_csc_col_common_methods!(CscCol<'a, T>);
|
||||
|
@ -666,33 +681,29 @@ impl<'a, T> CscColMut<'a, T> {
|
|||
|
||||
/// Column iterator for [CscMatrix](struct.CscMatrix.html).
|
||||
pub struct CscColIter<'a, T> {
|
||||
lane_iter: CsLaneIter<'a, T>
|
||||
lane_iter: CsLaneIter<'a, T>,
|
||||
}
|
||||
|
||||
impl<'a, T> Iterator for CscColIter<'a, T> {
|
||||
type Item = CscCol<'a, T>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
self.lane_iter
|
||||
.next()
|
||||
.map(|lane| CscCol { lane })
|
||||
self.lane_iter.next().map(|lane| CscCol { lane })
|
||||
}
|
||||
}
|
||||
|
||||
/// Mutable column iterator for [CscMatrix](struct.CscMatrix.html).
|
||||
pub struct CscColIterMut<'a, T> {
|
||||
lane_iter: CsLaneIterMut<'a, T>
|
||||
lane_iter: CsLaneIterMut<'a, T>,
|
||||
}
|
||||
|
||||
impl<'a, T> Iterator for CscColIterMut<'a, T>
|
||||
where
|
||||
T: 'a
|
||||
T: 'a,
|
||||
{
|
||||
type Item = CscColMut<'a, T>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
self.lane_iter
|
||||
.next()
|
||||
.map(|lane| CscColMut { lane })
|
||||
self.lane_iter.next().map(|lane| CscColMut { lane })
|
||||
}
|
||||
}
|
|
@ -2,15 +2,15 @@
|
|||
//!
|
||||
//! This is the module-level documentation. See [`CsrMatrix`] for the main documentation of the
|
||||
//! CSC implementation.
|
||||
use crate::{SparseFormatError, SparseFormatErrorKind, SparseEntry, SparseEntryMut};
|
||||
use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter};
|
||||
use crate::cs::{CsLane, CsLaneIter, CsLaneIterMut, CsLaneMut, CsMatrix};
|
||||
use crate::csc::CscMatrix;
|
||||
use crate::cs::{CsMatrix, CsLaneIterMut, CsLaneIter, CsLane, CsLaneMut};
|
||||
use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter};
|
||||
use crate::{SparseEntry, SparseEntryMut, SparseFormatError, SparseFormatErrorKind};
|
||||
|
||||
use nalgebra::Scalar;
|
||||
use num_traits::{One};
|
||||
use num_traits::One;
|
||||
|
||||
use std::slice::{IterMut, Iter};
|
||||
use std::slice::{Iter, IterMut};
|
||||
|
||||
/// A CSR representation of a sparse matrix.
|
||||
///
|
||||
|
@ -130,7 +130,7 @@ impl<T> CsrMatrix<T> {
|
|||
/// Create a zero CSR matrix with no explicitly stored entries.
|
||||
pub fn zeros(nrows: usize, ncols: usize) -> Self {
|
||||
Self {
|
||||
cs: CsMatrix::new(nrows, ncols)
|
||||
cs: CsMatrix::new(nrows, ncols),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -198,8 +198,12 @@ impl<T> CsrMatrix<T> {
|
|||
values: Vec<T>,
|
||||
) -> Result<Self, SparseFormatError> {
|
||||
let pattern = SparsityPattern::try_from_offsets_and_indices(
|
||||
num_rows, num_cols, row_offsets, col_indices)
|
||||
.map_err(pattern_format_error_to_csr_error)?;
|
||||
num_rows,
|
||||
num_cols,
|
||||
row_offsets,
|
||||
col_indices,
|
||||
)
|
||||
.map_err(pattern_format_error_to_csr_error)?;
|
||||
Self::try_from_pattern_and_values(pattern, values)
|
||||
}
|
||||
|
||||
|
@ -207,16 +211,19 @@ impl<T> CsrMatrix<T> {
|
|||
///
|
||||
/// Returns an error if the number of values does not match the number of minor indices
|
||||
/// in the pattern.
|
||||
pub fn try_from_pattern_and_values(pattern: SparsityPattern, values: Vec<T>)
|
||||
-> Result<Self, SparseFormatError> {
|
||||
pub fn try_from_pattern_and_values(
|
||||
pattern: SparsityPattern,
|
||||
values: Vec<T>,
|
||||
) -> Result<Self, SparseFormatError> {
|
||||
if pattern.nnz() == values.len() {
|
||||
Ok(Self {
|
||||
cs: CsMatrix::from_pattern_and_values(pattern, values)
|
||||
cs: CsMatrix::from_pattern_and_values(pattern, values),
|
||||
})
|
||||
} else {
|
||||
Err(SparseFormatError::from_kind_and_msg(
|
||||
SparseFormatErrorKind::InvalidStructure,
|
||||
"Number of values and column indices must be the same"))
|
||||
"Number of values and column indices must be the same",
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -241,7 +248,7 @@ impl<T> CsrMatrix<T> {
|
|||
pub fn triplet_iter(&self) -> CsrTripletIter<T> {
|
||||
CsrTripletIter {
|
||||
pattern_iter: self.pattern().entries(),
|
||||
values_iter: self.values().iter()
|
||||
values_iter: self.values().iter(),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -272,7 +279,7 @@ impl<T> CsrMatrix<T> {
|
|||
let (pattern, values) = self.cs.pattern_and_values_mut();
|
||||
CsrTripletIterMut {
|
||||
pattern_iter: pattern.entries(),
|
||||
values_mut_iter: values.iter_mut()
|
||||
values_mut_iter: values.iter_mut(),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -283,8 +290,7 @@ impl<T> CsrMatrix<T> {
|
|||
/// Panics if row index is out of bounds.
|
||||
#[inline]
|
||||
pub fn row(&self, index: usize) -> CsrRow<T> {
|
||||
self.get_row(index)
|
||||
.expect("Row index must be in bounds")
|
||||
self.get_row(index).expect("Row index must be in bounds")
|
||||
}
|
||||
|
||||
/// Mutable row access for the given row index.
|
||||
|
@ -301,23 +307,19 @@ impl<T> CsrMatrix<T> {
|
|||
/// Return the row at the given row index, or `None` if out of bounds.
|
||||
#[inline]
|
||||
pub fn get_row(&self, index: usize) -> Option<CsrRow<T>> {
|
||||
self.cs
|
||||
.get_lane(index)
|
||||
.map(|lane| CsrRow { lane })
|
||||
self.cs.get_lane(index).map(|lane| CsrRow { lane })
|
||||
}
|
||||
|
||||
/// Mutable row access for the given row index, or `None` if out of bounds.
|
||||
#[inline]
|
||||
pub fn get_row_mut(&mut self, index: usize) -> Option<CsrRowMut<T>> {
|
||||
self.cs
|
||||
.get_lane_mut(index)
|
||||
.map(|lane| CsrRowMut { lane })
|
||||
self.cs.get_lane_mut(index).map(|lane| CsrRowMut { lane })
|
||||
}
|
||||
|
||||
/// An iterator over rows in the matrix.
|
||||
pub fn row_iter(&self) -> CsrRowIter<T> {
|
||||
CsrRowIter {
|
||||
lane_iter: CsLaneIter::new(self.pattern(), self.values())
|
||||
lane_iter: CsLaneIter::new(self.pattern(), self.values()),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -399,8 +401,11 @@ impl<T> CsrMatrix<T> {
|
|||
///
|
||||
/// Each call to this function incurs the cost of a binary search among the explicitly
|
||||
/// stored column entries for the given row.
|
||||
pub fn get_entry_mut(&mut self, row_index: usize, col_index: usize)
|
||||
-> Option<SparseEntryMut<T>> {
|
||||
pub fn get_entry_mut(
|
||||
&mut self,
|
||||
row_index: usize,
|
||||
col_index: usize,
|
||||
) -> Option<SparseEntryMut<T>> {
|
||||
self.cs.get_entry_mut(row_index, col_index)
|
||||
}
|
||||
|
||||
|
@ -444,19 +449,23 @@ impl<T> CsrMatrix<T> {
|
|||
/// Creates a sparse matrix that contains only the explicit entries decided by the
|
||||
/// given predicate.
|
||||
pub fn filter<P>(&self, predicate: P) -> Self
|
||||
where
|
||||
T: Clone,
|
||||
P: Fn(usize, usize, &T) -> bool
|
||||
where
|
||||
T: Clone,
|
||||
P: Fn(usize, usize, &T) -> bool,
|
||||
{
|
||||
Self { cs: self.cs.filter(|row_idx, col_idx, v| predicate(row_idx, col_idx, v)) }
|
||||
Self {
|
||||
cs: self
|
||||
.cs
|
||||
.filter(|row_idx, col_idx, v| predicate(row_idx, col_idx, v)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a new matrix representing the upper triangular part of this matrix.
|
||||
///
|
||||
/// The result includes the diagonal of the matrix.
|
||||
pub fn upper_triangle(&self) -> Self
|
||||
where
|
||||
T: Clone
|
||||
where
|
||||
T: Clone,
|
||||
{
|
||||
self.filter(|i, j, _| i <= j)
|
||||
}
|
||||
|
@ -465,24 +474,26 @@ impl<T> CsrMatrix<T> {
|
|||
///
|
||||
/// The result includes the diagonal of the matrix.
|
||||
pub fn lower_triangle(&self) -> Self
|
||||
where
|
||||
T: Clone
|
||||
where
|
||||
T: Clone,
|
||||
{
|
||||
self.filter(|i, j, _| i >= j)
|
||||
}
|
||||
|
||||
/// Returns the diagonal of the matrix as a sparse matrix.
|
||||
pub fn diagonal_as_csr(&self) -> Self
|
||||
where
|
||||
T: Clone
|
||||
where
|
||||
T: Clone,
|
||||
{
|
||||
Self { cs: self.cs.diagonal_as_matrix() }
|
||||
Self {
|
||||
cs: self.cs.diagonal_as_matrix(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> CsrMatrix<T>
|
||||
where
|
||||
T: Scalar
|
||||
T: Scalar,
|
||||
{
|
||||
/// Compute the transpose of the matrix.
|
||||
pub fn transpose(&self) -> CsrMatrix<T> {
|
||||
|
@ -495,7 +506,7 @@ impl<T: Scalar + One> CsrMatrix<T> {
|
|||
#[inline]
|
||||
pub fn identity(n: usize) -> Self {
|
||||
Self {
|
||||
cs: CsMatrix::identity(n)
|
||||
cs: CsMatrix::identity(n),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -505,30 +516,34 @@ impl<T: Scalar + One> CsrMatrix<T> {
|
|||
/// This ensures that the terminology is consistent: we are talking about rows and columns,
|
||||
/// not lanes, major and minor dimensions.
|
||||
fn pattern_format_error_to_csr_error(err: SparsityPatternFormatError) -> SparseFormatError {
|
||||
use SparsityPatternFormatError::*;
|
||||
use SparsityPatternFormatError::DuplicateEntry as PatternDuplicateEntry;
|
||||
use SparseFormatError as E;
|
||||
use SparseFormatErrorKind as K;
|
||||
use SparsityPatternFormatError::DuplicateEntry as PatternDuplicateEntry;
|
||||
use SparsityPatternFormatError::*;
|
||||
|
||||
match err {
|
||||
InvalidOffsetArrayLength => E::from_kind_and_msg(
|
||||
K::InvalidStructure,
|
||||
"Length of row offset array is not equal to nrows + 1."),
|
||||
"Length of row offset array is not equal to nrows + 1.",
|
||||
),
|
||||
InvalidOffsetFirstLast => E::from_kind_and_msg(
|
||||
K::InvalidStructure,
|
||||
"First or last row offset is inconsistent with format specification."),
|
||||
"First or last row offset is inconsistent with format specification.",
|
||||
),
|
||||
NonmonotonicOffsets => E::from_kind_and_msg(
|
||||
K::InvalidStructure,
|
||||
"Row offsets are not monotonically increasing."),
|
||||
"Row offsets are not monotonically increasing.",
|
||||
),
|
||||
NonmonotonicMinorIndices => E::from_kind_and_msg(
|
||||
K::InvalidStructure,
|
||||
"Column indices are not monotonically increasing (sorted) within each row."),
|
||||
MinorIndexOutOfBounds => E::from_kind_and_msg(
|
||||
K::IndexOutOfBounds,
|
||||
"Column indices are out of bounds."),
|
||||
PatternDuplicateEntry => E::from_kind_and_msg(
|
||||
K::DuplicateEntry,
|
||||
"Matrix data contains duplicate entries."),
|
||||
"Column indices are not monotonically increasing (sorted) within each row.",
|
||||
),
|
||||
MinorIndexOutOfBounds => {
|
||||
E::from_kind_and_msg(K::IndexOutOfBounds, "Column indices are out of bounds.")
|
||||
}
|
||||
PatternDuplicateEntry => {
|
||||
E::from_kind_and_msg(K::DuplicateEntry, "Matrix data contains duplicate entries.")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -536,7 +551,7 @@ fn pattern_format_error_to_csr_error(err: SparsityPatternFormatError) -> SparseF
|
|||
#[derive(Debug)]
|
||||
pub struct CsrTripletIter<'a, T> {
|
||||
pattern_iter: SparsityPatternIter<'a>,
|
||||
values_iter: Iter<'a, T>
|
||||
values_iter: Iter<'a, T>,
|
||||
}
|
||||
|
||||
impl<'a, T: Clone> CsrTripletIter<'a, T> {
|
||||
|
@ -545,7 +560,7 @@ impl<'a, T: Clone> CsrTripletIter<'a, T> {
|
|||
/// The triplet iterator returns references to the values. This method adapts the iterator
|
||||
/// so that the values are cloned.
|
||||
#[inline]
|
||||
pub fn cloned_values(self) -> impl 'a + Iterator<Item=(usize, usize, T)> {
|
||||
pub fn cloned_values(self) -> impl 'a + Iterator<Item = (usize, usize, T)> {
|
||||
self.map(|(i, j, v)| (i, j, v.clone()))
|
||||
}
|
||||
}
|
||||
|
@ -559,7 +574,7 @@ impl<'a, T> Iterator for CsrTripletIter<'a, T> {
|
|||
|
||||
match (next_entry, next_value) {
|
||||
(Some((i, j)), Some(v)) => Some((i, j, v)),
|
||||
_ => None
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -568,7 +583,7 @@ impl<'a, T> Iterator for CsrTripletIter<'a, T> {
|
|||
#[derive(Debug)]
|
||||
pub struct CsrTripletIterMut<'a, T> {
|
||||
pattern_iter: SparsityPatternIter<'a>,
|
||||
values_mut_iter: IterMut<'a, T>
|
||||
values_mut_iter: IterMut<'a, T>,
|
||||
}
|
||||
|
||||
impl<'a, T> Iterator for CsrTripletIterMut<'a, T> {
|
||||
|
@ -581,7 +596,7 @@ impl<'a, T> Iterator for CsrTripletIterMut<'a, T> {
|
|||
|
||||
match (next_entry, next_value) {
|
||||
(Some((i, j)), Some(v)) => Some((i, j, v)),
|
||||
_ => None
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -589,7 +604,7 @@ impl<'a, T> Iterator for CsrTripletIterMut<'a, T> {
|
|||
/// An immutable representation of a row in a CSR matrix.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct CsrRow<'a, T> {
|
||||
lane: CsLane<'a, T>
|
||||
lane: CsLane<'a, T>,
|
||||
}
|
||||
|
||||
/// A mutable representation of a row in a CSR matrix.
|
||||
|
@ -598,7 +613,7 @@ pub struct CsrRow<'a, T> {
|
|||
/// to the row cannot be modified.
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
pub struct CsrRowMut<'a, T> {
|
||||
lane: CsLaneMut<'a, T>
|
||||
lane: CsLaneMut<'a, T>,
|
||||
}
|
||||
|
||||
/// Implement the methods common to both CsrRow and CsrRowMut
|
||||
|
@ -638,7 +653,7 @@ macro_rules! impl_csr_row_common_methods {
|
|||
self.lane.get_entry(global_col_index)
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
impl_csr_row_common_methods!(CsrRow<'a, T>);
|
||||
|
@ -670,33 +685,29 @@ impl<'a, T> CsrRowMut<'a, T> {
|
|||
|
||||
/// Row iterator for [CsrMatrix](struct.CsrMatrix.html).
|
||||
pub struct CsrRowIter<'a, T> {
|
||||
lane_iter: CsLaneIter<'a, T>
|
||||
lane_iter: CsLaneIter<'a, T>,
|
||||
}
|
||||
|
||||
impl<'a, T> Iterator for CsrRowIter<'a, T> {
|
||||
type Item = CsrRow<'a, T>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
self.lane_iter
|
||||
.next()
|
||||
.map(|lane| CsrRow { lane })
|
||||
self.lane_iter.next().map(|lane| CsrRow { lane })
|
||||
}
|
||||
}
|
||||
|
||||
/// Mutable row iterator for [CsrMatrix](struct.CsrMatrix.html).
|
||||
pub struct CsrRowIterMut<'a, T> {
|
||||
lane_iter: CsLaneIterMut<'a, T>
|
||||
lane_iter: CsLaneIterMut<'a, T>,
|
||||
}
|
||||
|
||||
impl<'a, T> Iterator for CsrRowIterMut<'a, T>
|
||||
where
|
||||
T: 'a
|
||||
T: 'a,
|
||||
{
|
||||
type Item = CsrRowMut<'a, T>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
self.lane_iter
|
||||
.next()
|
||||
.map(|lane| CsrRowMut { lane })
|
||||
self.lane_iter.next().map(|lane| CsrRowMut { lane })
|
||||
}
|
||||
}
|
|
@ -1,10 +1,10 @@
|
|||
use crate::pattern::SparsityPattern;
|
||||
use crate::csc::CscMatrix;
|
||||
use core::{mem, iter};
|
||||
use nalgebra::{Scalar, RealField, DMatrixSlice, DMatrixSliceMut, DMatrix};
|
||||
use std::fmt::{Display, Formatter};
|
||||
use crate::ops::serial::spsolve_csc_lower_triangular;
|
||||
use crate::ops::Op;
|
||||
use crate::pattern::SparsityPattern;
|
||||
use core::{iter, mem};
|
||||
use nalgebra::{DMatrix, DMatrixSlice, DMatrixSliceMut, RealField, Scalar};
|
||||
use std::fmt::{Display, Formatter};
|
||||
|
||||
/// A symbolic sparse Cholesky factorization of a CSC matrix.
|
||||
///
|
||||
|
@ -15,7 +15,7 @@ pub struct CscSymbolicCholesky {
|
|||
m_pattern: SparsityPattern,
|
||||
l_pattern: SparsityPattern,
|
||||
// u in this context is L^T, so that M = L L^T
|
||||
u_pattern: SparsityPattern
|
||||
u_pattern: SparsityPattern,
|
||||
}
|
||||
|
||||
impl CscSymbolicCholesky {
|
||||
|
@ -28,8 +28,11 @@ impl CscSymbolicCholesky {
|
|||
///
|
||||
/// Panics if the sparsity pattern is not square.
|
||||
pub fn factor(pattern: SparsityPattern) -> Self {
|
||||
assert_eq!(pattern.major_dim(), pattern.minor_dim(),
|
||||
"Major and minor dimensions must be the same (square matrix).");
|
||||
assert_eq!(
|
||||
pattern.major_dim(),
|
||||
pattern.minor_dim(),
|
||||
"Major and minor dimensions must be the same (square matrix)."
|
||||
);
|
||||
let (l_pattern, u_pattern) = nonzero_pattern(&pattern);
|
||||
Self {
|
||||
m_pattern: pattern,
|
||||
|
@ -65,7 +68,7 @@ pub struct CscCholesky<T> {
|
|||
l_factor: CscMatrix<T>,
|
||||
u_pattern: SparsityPattern,
|
||||
work_x: Vec<T>,
|
||||
work_c: Vec<usize>
|
||||
work_c: Vec<usize>,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone)]
|
||||
|
@ -100,16 +103,20 @@ impl<T: RealField> CscCholesky<T> {
|
|||
///
|
||||
/// Panics if the number of values differ from the number of non-zeros of the sparsity pattern
|
||||
/// of the matrix that was symbolically factored.
|
||||
pub fn factor_numerical(symbolic: CscSymbolicCholesky, values: &[T])
|
||||
-> Result<Self, CholeskyError>
|
||||
{
|
||||
assert_eq!(symbolic.l_pattern.nnz(), symbolic.u_pattern.nnz(),
|
||||
"u is just the transpose of l, so should have the same nnz");
|
||||
pub fn factor_numerical(
|
||||
symbolic: CscSymbolicCholesky,
|
||||
values: &[T],
|
||||
) -> Result<Self, CholeskyError> {
|
||||
assert_eq!(
|
||||
symbolic.l_pattern.nnz(),
|
||||
symbolic.u_pattern.nnz(),
|
||||
"u is just the transpose of l, so should have the same nnz"
|
||||
);
|
||||
|
||||
let l_nnz = symbolic.l_pattern.nnz();
|
||||
let l_values = vec![T::zero(); l_nnz];
|
||||
let l_factor = CscMatrix::try_from_pattern_and_values(symbolic.l_pattern, l_values)
|
||||
.unwrap();
|
||||
let l_factor =
|
||||
CscMatrix::try_from_pattern_and_values(symbolic.l_pattern, l_values).unwrap();
|
||||
|
||||
let (nrows, ncols) = (l_factor.nrows(), l_factor.ncols());
|
||||
|
||||
|
@ -169,7 +176,7 @@ impl<T: RealField> CscCholesky<T> {
|
|||
}
|
||||
|
||||
/// Returns the Cholesky factor `L`.
|
||||
pub fn take_l(self) -> CscMatrix<T> {
|
||||
pub fn take_l(self) -> CscMatrix<T> {
|
||||
self.l_factor
|
||||
}
|
||||
|
||||
|
@ -229,11 +236,9 @@ impl<T: RealField> CscCholesky<T> {
|
|||
|
||||
{
|
||||
let (offsets, _, values) = self.l_factor.csc_data_mut();
|
||||
*values
|
||||
.get_unchecked_mut(*offsets.get_unchecked(k)) = denom;
|
||||
*values.get_unchecked_mut(*offsets.get_unchecked(k)) = denom;
|
||||
}
|
||||
|
||||
|
||||
let mut col_k = self.l_factor.col_mut(k);
|
||||
let (col_k_rows, col_k_values) = col_k.rows_and_values_mut();
|
||||
let col_k_entries = col_k_rows.iter().zip(col_k_values);
|
||||
|
@ -269,19 +274,16 @@ impl<T: RealField> CscCholesky<T> {
|
|||
/// # Panics
|
||||
///
|
||||
/// Panics if `b` is not square.
|
||||
pub fn solve_mut<'a>(&'a self, b: impl Into<DMatrixSliceMut<'a, T>>)
|
||||
{
|
||||
pub fn solve_mut<'a>(&'a self, b: impl Into<DMatrixSliceMut<'a, T>>) {
|
||||
let expect_msg = "If the Cholesky factorization succeeded,\
|
||||
then the triangular solve should never fail";
|
||||
// Solve LY = B
|
||||
let mut y = b.into();
|
||||
spsolve_csc_lower_triangular(Op::NoOp(self.l()), &mut y)
|
||||
.expect(expect_msg);
|
||||
spsolve_csc_lower_triangular(Op::NoOp(self.l()), &mut y).expect(expect_msg);
|
||||
|
||||
// Solve L^T X = Y
|
||||
let mut x = y;
|
||||
spsolve_csc_lower_triangular(Op::Transpose(self.l()), &mut x)
|
||||
.expect(expect_msg);
|
||||
spsolve_csc_lower_triangular(Op::Transpose(self.l()), &mut x).expect(expect_msg);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -333,8 +335,8 @@ fn nonzero_pattern(m: &SparsityPattern) -> (SparsityPattern, SparsityPattern) {
|
|||
col_offsets.push(rows.len());
|
||||
}
|
||||
|
||||
let u_pattern = SparsityPattern::try_from_offsets_and_indices(nrows, ncols, col_offsets, rows)
|
||||
.unwrap();
|
||||
let u_pattern =
|
||||
SparsityPattern::try_from_offsets_and_indices(nrows, ncols, col_offsets, rows).unwrap();
|
||||
|
||||
// TODO: Avoid this transpose?
|
||||
let l_pattern = u_pattern.transpose();
|
||||
|
|
|
@ -135,13 +135,13 @@
|
|||
#![deny(unused_results)]
|
||||
#![deny(missing_docs)]
|
||||
|
||||
pub mod convert;
|
||||
pub mod coo;
|
||||
pub mod csc;
|
||||
pub mod csr;
|
||||
pub mod pattern;
|
||||
pub mod ops;
|
||||
pub mod convert;
|
||||
pub mod factorization;
|
||||
pub mod ops;
|
||||
pub mod pattern;
|
||||
|
||||
pub(crate) mod cs;
|
||||
|
||||
|
@ -151,16 +151,16 @@ pub mod proptest;
|
|||
#[cfg(feature = "compare")]
|
||||
mod matrixcompare;
|
||||
|
||||
use num_traits::Zero;
|
||||
use std::error::Error;
|
||||
use std::fmt;
|
||||
use num_traits::Zero;
|
||||
|
||||
/// Errors produced by functions that expect well-formed sparse format data.
|
||||
#[derive(Debug)]
|
||||
pub struct SparseFormatError {
|
||||
kind: SparseFormatErrorKind,
|
||||
// Currently we only use an underlying error for generating the `Display` impl
|
||||
error: Box<dyn Error>
|
||||
error: Box<dyn Error>,
|
||||
}
|
||||
|
||||
impl SparseFormatError {
|
||||
|
@ -170,10 +170,7 @@ impl SparseFormatError {
|
|||
}
|
||||
|
||||
pub(crate) fn from_kind_and_error(kind: SparseFormatErrorKind, error: Box<dyn Error>) -> Self {
|
||||
Self {
|
||||
kind,
|
||||
error
|
||||
}
|
||||
Self { kind, error }
|
||||
}
|
||||
|
||||
/// Helper functionality for more conveniently creating errors.
|
||||
|
@ -221,7 +218,7 @@ pub enum SparseEntry<'a, T> {
|
|||
/// is explicitly stored (a so-called "explicit zero").
|
||||
NonZero(&'a T),
|
||||
/// The entry is implicitly zero, i.e. it is not explicitly stored.
|
||||
Zero
|
||||
Zero,
|
||||
}
|
||||
|
||||
impl<'a, T: Clone + Zero> SparseEntry<'a, T> {
|
||||
|
@ -232,7 +229,7 @@ impl<'a, T: Clone + Zero> SparseEntry<'a, T> {
|
|||
pub fn to_value(self) -> T {
|
||||
match self {
|
||||
SparseEntry::NonZero(value) => value.clone(),
|
||||
SparseEntry::Zero => T::zero()
|
||||
SparseEntry::Zero => T::zero(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -248,7 +245,7 @@ pub enum SparseEntryMut<'a, T> {
|
|||
/// is explicitly stored (a so-called "explicit zero").
|
||||
NonZero(&'a mut T),
|
||||
/// The entry is implicitly zero i.e. it is not explicitly stored.
|
||||
Zero
|
||||
Zero,
|
||||
}
|
||||
|
||||
impl<'a, T: Clone + Zero> SparseEntryMut<'a, T> {
|
||||
|
@ -259,7 +256,7 @@ impl<'a, T: Clone + Zero> SparseEntryMut<'a, T> {
|
|||
pub fn to_value(self) -> T {
|
||||
match self {
|
||||
SparseEntryMut::NonZero(value) => value.clone(),
|
||||
SparseEntryMut::Zero => T::zero()
|
||||
SparseEntryMut::Zero => T::zero(),
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,36 +1,38 @@
|
|||
//! Implements core traits for use with `matrixcompare`.
|
||||
use crate::csr::CsrMatrix;
|
||||
use crate::coo::CooMatrix;
|
||||
use crate::csc::CscMatrix;
|
||||
use crate::csr::CsrMatrix;
|
||||
use matrixcompare_core;
|
||||
use matrixcompare_core::{Access, SparseAccess};
|
||||
use crate::coo::CooMatrix;
|
||||
|
||||
macro_rules! impl_matrix_for_csr_csc {
|
||||
($MatrixType:ident) => {
|
||||
impl<T: Clone> SparseAccess<T> for $MatrixType<T> {
|
||||
fn nnz(&self) -> usize {
|
||||
$MatrixType::nnz(self)
|
||||
impl<T: Clone> SparseAccess<T> for $MatrixType<T> {
|
||||
fn nnz(&self) -> usize {
|
||||
$MatrixType::nnz(self)
|
||||
}
|
||||
|
||||
fn fetch_triplets(&self) -> Vec<(usize, usize, T)> {
|
||||
self.triplet_iter()
|
||||
.map(|(i, j, v)| (i, j, v.clone()))
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
fn fetch_triplets(&self) -> Vec<(usize, usize, T)> {
|
||||
self.triplet_iter().map(|(i, j, v)| (i, j, v.clone())).collect()
|
||||
}
|
||||
}
|
||||
impl<T: Clone> matrixcompare_core::Matrix<T> for $MatrixType<T> {
|
||||
fn rows(&self) -> usize {
|
||||
self.nrows()
|
||||
}
|
||||
|
||||
impl<T: Clone> matrixcompare_core::Matrix<T> for $MatrixType<T> {
|
||||
fn rows(&self) -> usize {
|
||||
self.nrows()
|
||||
}
|
||||
fn cols(&self) -> usize {
|
||||
self.ncols()
|
||||
}
|
||||
|
||||
fn cols(&self) -> usize {
|
||||
self.ncols()
|
||||
fn access(&self) -> Access<T> {
|
||||
Access::Sparse(self)
|
||||
}
|
||||
}
|
||||
|
||||
fn access(&self) -> Access<T> {
|
||||
Access::Sparse(self)
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
impl_matrix_for_csr_csc!(CsrMatrix);
|
||||
|
@ -42,7 +44,9 @@ impl<T: Clone> SparseAccess<T> for CooMatrix<T> {
|
|||
}
|
||||
|
||||
fn fetch_triplets(&self) -> Vec<(usize, usize, T)> {
|
||||
self.triplet_iter().map(|(i, j, v)| (i, j, v.clone())).collect()
|
||||
self.triplet_iter()
|
||||
.map(|(i, j, v)| (i, j, v.clone()))
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,15 +1,20 @@
|
|||
use crate::csr::CsrMatrix;
|
||||
use crate::csc::CscMatrix;
|
||||
use crate::csr::CsrMatrix;
|
||||
|
||||
use std::ops::{Add, Div, DivAssign, Mul, MulAssign, Sub, Neg};
|
||||
use crate::ops::serial::{spadd_csr_prealloc, spadd_csc_prealloc, spadd_pattern, spmm_csr_pattern, spmm_csr_prealloc, spmm_csc_prealloc, spmm_csc_dense, spmm_csr_dense, spmm_csc_pattern};
|
||||
use nalgebra::{ClosedAdd, ClosedMul, ClosedSub, ClosedDiv, Scalar, Matrix, MatrixMN, Dim,
|
||||
Dynamic, DefaultAllocator, U1};
|
||||
use nalgebra::allocator::{Allocator};
|
||||
use nalgebra::constraint::{DimEq, ShapeConstraint};
|
||||
use num_traits::{Zero, One};
|
||||
use crate::ops::{Op};
|
||||
use crate::ops::serial::{
|
||||
spadd_csc_prealloc, spadd_csr_prealloc, spadd_pattern, spmm_csc_dense, spmm_csc_pattern,
|
||||
spmm_csc_prealloc, spmm_csr_dense, spmm_csr_pattern, spmm_csr_prealloc,
|
||||
};
|
||||
use crate::ops::Op;
|
||||
use nalgebra::allocator::Allocator;
|
||||
use nalgebra::base::storage::Storage;
|
||||
use nalgebra::constraint::{DimEq, ShapeConstraint};
|
||||
use nalgebra::{
|
||||
ClosedAdd, ClosedDiv, ClosedMul, ClosedSub, DefaultAllocator, Dim, Dynamic, Matrix, MatrixMN,
|
||||
Scalar, U1,
|
||||
};
|
||||
use num_traits::{One, Zero};
|
||||
use std::ops::{Add, Div, DivAssign, Mul, MulAssign, Neg, Sub};
|
||||
|
||||
/// Helper macro for implementing binary operators for different matrix types
|
||||
/// See below for usage.
|
||||
|
@ -188,7 +193,7 @@ macro_rules! impl_neg {
|
|||
($matrix_type:ident) => {
|
||||
impl<T> Neg for $matrix_type<T>
|
||||
where
|
||||
T: Scalar + Neg<Output=T>
|
||||
T: Scalar + Neg<Output = T>,
|
||||
{
|
||||
type Output = $matrix_type<T>;
|
||||
|
||||
|
@ -202,7 +207,7 @@ macro_rules! impl_neg {
|
|||
|
||||
impl<'a, T> Neg for &'a $matrix_type<T>
|
||||
where
|
||||
T: Scalar + Neg<Output=T>
|
||||
T: Scalar + Neg<Output = T>,
|
||||
{
|
||||
type Output = $matrix_type<T>;
|
||||
|
||||
|
@ -211,10 +216,10 @@ macro_rules! impl_neg {
|
|||
// obtain both the sparsity pattern and values from the matrix,
|
||||
// and then modify the values before creating a new matrix from the pattern
|
||||
// and negated values.
|
||||
- self.clone()
|
||||
-self.clone()
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
impl_neg!(CsrMatrix);
|
||||
|
|
|
@ -148,13 +148,14 @@ impl<T> Op<T> {
|
|||
pub fn as_ref(&self) -> Op<&T> {
|
||||
match self {
|
||||
Op::NoOp(obj) => Op::NoOp(&obj),
|
||||
Op::Transpose(obj) => Op::Transpose(&obj)
|
||||
Op::Transpose(obj) => Op::Transpose(&obj),
|
||||
}
|
||||
}
|
||||
|
||||
/// Converts the underlying data type.
|
||||
pub fn convert<U>(self) -> Op<U>
|
||||
where T: Into<U>
|
||||
where
|
||||
T: Into<U>,
|
||||
{
|
||||
self.map_same_op(T::into)
|
||||
}
|
||||
|
@ -163,7 +164,7 @@ impl<T> Op<T> {
|
|||
pub fn map_same_op<U, F: FnOnce(T) -> U>(self, f: F) -> Op<U> {
|
||||
match self {
|
||||
Op::NoOp(obj) => Op::NoOp(f(obj)),
|
||||
Op::Transpose(obj) => Op::Transpose(f(obj))
|
||||
Op::Transpose(obj) => Op::Transpose(f(obj)),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -181,7 +182,7 @@ impl<T> Op<T> {
|
|||
pub fn transposed(self) -> Self {
|
||||
match self {
|
||||
Op::NoOp(obj) => Op::Transpose(obj),
|
||||
Op::Transpose(obj) => Op::NoOp(obj)
|
||||
Op::Transpose(obj) => Op::NoOp(obj),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -191,4 +192,3 @@ impl<T> From<T> for Op<T> {
|
|||
Self::NoOp(obj)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,14 +1,15 @@
|
|||
use crate::cs::CsMatrix;
|
||||
use crate::ops::serial::{OperationError, OperationErrorKind};
|
||||
use crate::ops::Op;
|
||||
use crate::ops::serial::{OperationErrorKind, OperationError};
|
||||
use nalgebra::{Scalar, ClosedAdd, ClosedMul, DMatrixSliceMut, DMatrixSlice};
|
||||
use num_traits::{Zero, One};
|
||||
use crate::SparseEntryMut;
|
||||
use nalgebra::{ClosedAdd, ClosedMul, DMatrixSlice, DMatrixSliceMut, Scalar};
|
||||
use num_traits::{One, Zero};
|
||||
|
||||
fn spmm_cs_unexpected_entry() -> OperationError {
|
||||
OperationError::from_kind_and_message(
|
||||
OperationErrorKind::InvalidPattern,
|
||||
String::from("Found unexpected entry that is not present in `c`."))
|
||||
String::from("Found unexpected entry that is not present in `c`."),
|
||||
)
|
||||
}
|
||||
|
||||
/// Helper functionality for implementing CSR/CSC SPMM.
|
||||
|
@ -24,12 +25,12 @@ pub fn spmm_cs_prealloc<T>(
|
|||
c: &mut CsMatrix<T>,
|
||||
alpha: T,
|
||||
a: &CsMatrix<T>,
|
||||
b: &CsMatrix<T>)
|
||||
-> Result<(), OperationError>
|
||||
where
|
||||
T: Scalar + ClosedAdd + ClosedMul + Zero + One
|
||||
b: &CsMatrix<T>,
|
||||
) -> Result<(), OperationError>
|
||||
where
|
||||
T: Scalar + ClosedAdd + ClosedMul + Zero + One,
|
||||
{
|
||||
for i in 0 .. c.pattern().major_dim() {
|
||||
for i in 0..c.pattern().major_dim() {
|
||||
let a_lane_i = a.get_lane(i).unwrap();
|
||||
let mut c_lane_i = c.get_lane_mut(i).unwrap();
|
||||
for c_ij in c_lane_i.values_mut() {
|
||||
|
@ -42,14 +43,15 @@ pub fn spmm_cs_prealloc<T>(
|
|||
let alpha_aik = alpha.inlined_clone() * a_ik.inlined_clone();
|
||||
for (j, b_kj) in b_lane_k.minor_indices().iter().zip(b_lane_k.values()) {
|
||||
// Determine the location in C to append the value
|
||||
let (c_local_idx, _) = c_lane_i_cols.iter()
|
||||
let (c_local_idx, _) = c_lane_i_cols
|
||||
.iter()
|
||||
.enumerate()
|
||||
.find(|(_, c_col)| *c_col == j)
|
||||
.ok_or_else(spmm_cs_unexpected_entry)?;
|
||||
|
||||
c_lane_i_values[c_local_idx] += alpha_aik.inlined_clone() * b_kj.inlined_clone();
|
||||
c_lane_i_cols = &c_lane_i_cols[c_local_idx ..];
|
||||
c_lane_i_values = &mut c_lane_i_values[c_local_idx ..];
|
||||
c_lane_i_cols = &c_lane_i_cols[c_local_idx..];
|
||||
c_lane_i_values = &mut c_lane_i_values[c_local_idx..];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -60,17 +62,19 @@ pub fn spmm_cs_prealloc<T>(
|
|||
fn spadd_cs_unexpected_entry() -> OperationError {
|
||||
OperationError::from_kind_and_message(
|
||||
OperationErrorKind::InvalidPattern,
|
||||
String::from("Found entry in `op(a)` that is not present in `c`."))
|
||||
String::from("Found entry in `op(a)` that is not present in `c`."),
|
||||
)
|
||||
}
|
||||
|
||||
/// Helper functionality for implementing CSR/CSC SPADD.
|
||||
pub fn spadd_cs_prealloc<T>(beta: T,
|
||||
c: &mut CsMatrix<T>,
|
||||
alpha: T,
|
||||
a: Op<&CsMatrix<T>>)
|
||||
-> Result<(), OperationError>
|
||||
where
|
||||
T: Scalar + ClosedAdd + ClosedMul + Zero + One
|
||||
pub fn spadd_cs_prealloc<T>(
|
||||
beta: T,
|
||||
c: &mut CsMatrix<T>,
|
||||
alpha: T,
|
||||
a: Op<&CsMatrix<T>>,
|
||||
) -> Result<(), OperationError>
|
||||
where
|
||||
T: Scalar + ClosedAdd + ClosedMul + Zero + One,
|
||||
{
|
||||
match a {
|
||||
Op::NoOp(a) => {
|
||||
|
@ -88,13 +92,14 @@ pub fn spadd_cs_prealloc<T>(beta: T,
|
|||
// TODO: Use exponential search instead of linear search.
|
||||
// If C has substantially more entries in the row than A, then a line search
|
||||
// will needlessly visit many entries in C.
|
||||
let (c_idx, _) = c_minors.iter()
|
||||
let (c_idx, _) = c_minors
|
||||
.iter()
|
||||
.enumerate()
|
||||
.find(|(_, c_col)| *c_col == a_col)
|
||||
.ok_or_else(spadd_cs_unexpected_entry)?;
|
||||
c_vals[c_idx] += alpha.inlined_clone() * a_val.inlined_clone();
|
||||
c_minors = &c_minors[c_idx ..];
|
||||
c_vals = &mut c_vals[c_idx ..];
|
||||
c_minors = &c_minors[c_idx..];
|
||||
c_vals = &mut c_vals[c_idx..];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -110,7 +115,7 @@ pub fn spadd_cs_prealloc<T>(beta: T,
|
|||
let a_val = a_val.inlined_clone();
|
||||
let alpha = alpha.inlined_clone();
|
||||
match c.get_entry_mut(j, i).unwrap() {
|
||||
SparseEntryMut::NonZero(c_ji) => { *c_ji += alpha * a_val }
|
||||
SparseEntryMut::NonZero(c_ji) => *c_ji += alpha * a_val,
|
||||
SparseEntryMut::Zero => return Err(spadd_cs_unexpected_entry()),
|
||||
}
|
||||
}
|
||||
|
@ -124,13 +129,14 @@ pub fn spadd_cs_prealloc<T>(beta: T,
|
|||
///
|
||||
/// The implementation essentially assumes that `a` is a CSR matrix. To use it with CSC matrices,
|
||||
/// the transposed operation must be specified for the CSC matrix.
|
||||
pub fn spmm_cs_dense<T>(beta: T,
|
||||
mut c: DMatrixSliceMut<T>,
|
||||
alpha: T,
|
||||
a: Op<&CsMatrix<T>>,
|
||||
b: Op<DMatrixSlice<T>>)
|
||||
where
|
||||
T: Scalar + ClosedAdd + ClosedMul + Zero + One
|
||||
pub fn spmm_cs_dense<T>(
|
||||
beta: T,
|
||||
mut c: DMatrixSliceMut<T>,
|
||||
alpha: T,
|
||||
a: Op<&CsMatrix<T>>,
|
||||
b: Op<DMatrixSlice<T>>,
|
||||
) where
|
||||
T: Scalar + ClosedAdd + ClosedMul + Zero + One,
|
||||
{
|
||||
match a {
|
||||
Op::NoOp(a) => {
|
||||
|
@ -139,17 +145,17 @@ pub fn spmm_cs_dense<T>(beta: T,
|
|||
for (c_ij, a_row_i) in c_col_j.iter_mut().zip(a.lane_iter()) {
|
||||
let mut dot_ij = T::zero();
|
||||
for (&k, a_ik) in a_row_i.minor_indices().iter().zip(a_row_i.values()) {
|
||||
let b_contrib =
|
||||
match b {
|
||||
Op::NoOp(ref b) => b.index((k, j)),
|
||||
Op::Transpose(ref b) => b.index((j, k))
|
||||
};
|
||||
let b_contrib = match b {
|
||||
Op::NoOp(ref b) => b.index((k, j)),
|
||||
Op::Transpose(ref b) => b.index((j, k)),
|
||||
};
|
||||
dot_ij += a_ik.inlined_clone() * b_contrib.inlined_clone();
|
||||
}
|
||||
*c_ij = beta.inlined_clone() * c_ij.inlined_clone() + alpha.inlined_clone() * dot_ij;
|
||||
*c_ij = beta.inlined_clone() * c_ij.inlined_clone()
|
||||
+ alpha.inlined_clone() * dot_ij;
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
Op::Transpose(a) => {
|
||||
// In this case, we have to pre-multiply C by beta
|
||||
c *= beta;
|
||||
|
@ -165,17 +171,16 @@ pub fn spmm_cs_dense<T>(beta: T,
|
|||
for (c_ij, b_kj) in c_row_i.iter_mut().zip(b_row_k.iter()) {
|
||||
*c_ij += gamma_ki.inlined_clone() * b_kj.inlined_clone();
|
||||
}
|
||||
},
|
||||
}
|
||||
Op::Transpose(ref b) => {
|
||||
let b_col_k = b.column(k);
|
||||
for (c_ij, b_jk) in c_row_i.iter_mut().zip(b_col_k.iter()) {
|
||||
*c_ij += gamma_ki.inlined_clone() * b_jk.inlined_clone();
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
use crate::csc::CscMatrix;
|
||||
use crate::ops::Op;
|
||||
use crate::ops::serial::cs::{spmm_cs_prealloc, spmm_cs_dense, spadd_cs_prealloc};
|
||||
use crate::ops::serial::cs::{spadd_cs_prealloc, spmm_cs_dense, spmm_cs_prealloc};
|
||||
use crate::ops::serial::{OperationError, OperationErrorKind};
|
||||
use nalgebra::{Scalar, ClosedAdd, ClosedMul, DMatrixSliceMut, DMatrixSlice, RealField};
|
||||
use num_traits::{Zero, One};
|
||||
use crate::ops::Op;
|
||||
use nalgebra::{ClosedAdd, ClosedMul, DMatrixSlice, DMatrixSliceMut, RealField, Scalar};
|
||||
use num_traits::{One, Zero};
|
||||
|
||||
use std::borrow::Cow;
|
||||
|
||||
|
@ -12,25 +12,27 @@ use std::borrow::Cow;
|
|||
/// # Panics
|
||||
///
|
||||
/// Panics if the dimensions of the matrices involved are not compatible with the expression.
|
||||
pub fn spmm_csc_dense<'a, T>(beta: T,
|
||||
c: impl Into<DMatrixSliceMut<'a, T>>,
|
||||
alpha: T,
|
||||
a: Op<&CscMatrix<T>>,
|
||||
b: Op<impl Into<DMatrixSlice<'a, T>>>)
|
||||
where
|
||||
T: Scalar + ClosedAdd + ClosedMul + Zero + One
|
||||
pub fn spmm_csc_dense<'a, T>(
|
||||
beta: T,
|
||||
c: impl Into<DMatrixSliceMut<'a, T>>,
|
||||
alpha: T,
|
||||
a: Op<&CscMatrix<T>>,
|
||||
b: Op<impl Into<DMatrixSlice<'a, T>>>,
|
||||
) where
|
||||
T: Scalar + ClosedAdd + ClosedMul + Zero + One,
|
||||
{
|
||||
let b = b.convert();
|
||||
spmm_csc_dense_(beta, c.into(), alpha, a, b)
|
||||
}
|
||||
|
||||
fn spmm_csc_dense_<T>(beta: T,
|
||||
c: DMatrixSliceMut<T>,
|
||||
alpha: T,
|
||||
a: Op<&CscMatrix<T>>,
|
||||
b: Op<DMatrixSlice<T>>)
|
||||
where
|
||||
T: Scalar + ClosedAdd + ClosedMul + Zero + One
|
||||
fn spmm_csc_dense_<T>(
|
||||
beta: T,
|
||||
c: DMatrixSliceMut<T>,
|
||||
alpha: T,
|
||||
a: Op<&CscMatrix<T>>,
|
||||
b: Op<DMatrixSlice<T>>,
|
||||
) where
|
||||
T: Scalar + ClosedAdd + ClosedMul + Zero + One,
|
||||
{
|
||||
assert_compatible_spmm_dims!(c, a, b);
|
||||
// Need to interpret matrix as transposed since the spmm_cs_dense function assumes CSR layout
|
||||
|
@ -46,19 +48,19 @@ fn spmm_csc_dense_<T>(beta: T,
|
|||
/// # Panics
|
||||
///
|
||||
/// Panics if the dimensions of the matrices involved are not compatible with the expression.
|
||||
pub fn spadd_csc_prealloc<T>(beta: T,
|
||||
c: &mut CscMatrix<T>,
|
||||
alpha: T,
|
||||
a: Op<&CscMatrix<T>>)
|
||||
-> Result<(), OperationError>
|
||||
where
|
||||
T: Scalar + ClosedAdd + ClosedMul + Zero + One
|
||||
pub fn spadd_csc_prealloc<T>(
|
||||
beta: T,
|
||||
c: &mut CscMatrix<T>,
|
||||
alpha: T,
|
||||
a: Op<&CscMatrix<T>>,
|
||||
) -> Result<(), OperationError>
|
||||
where
|
||||
T: Scalar + ClosedAdd + ClosedMul + Zero + One,
|
||||
{
|
||||
assert_compatible_spadd_dims!(c, a);
|
||||
spadd_cs_prealloc(beta, &mut c.cs, alpha, a.map_same_op(|a| &a.cs))
|
||||
}
|
||||
|
||||
|
||||
/// Sparse-sparse matrix multiplication, `C <- beta * C + alpha * op(A) * op(B)`.
|
||||
///
|
||||
/// # Errors
|
||||
|
@ -74,10 +76,10 @@ pub fn spmm_csc_prealloc<T>(
|
|||
c: &mut CscMatrix<T>,
|
||||
alpha: T,
|
||||
a: Op<&CscMatrix<T>>,
|
||||
b: Op<&CscMatrix<T>>)
|
||||
-> Result<(), OperationError>
|
||||
where
|
||||
T: Scalar + ClosedAdd + ClosedMul + Zero + One
|
||||
b: Op<&CscMatrix<T>>,
|
||||
) -> Result<(), OperationError>
|
||||
where
|
||||
T: Scalar + ClosedAdd + ClosedMul + Zero + One,
|
||||
{
|
||||
assert_compatible_spmm_dims!(c, a, b);
|
||||
|
||||
|
@ -87,7 +89,7 @@ pub fn spmm_csc_prealloc<T>(
|
|||
(NoOp(ref a), NoOp(ref b)) => {
|
||||
// Note: We have to reverse the order for CSC matrices
|
||||
spmm_cs_prealloc(beta, &mut c.cs, alpha, &b.cs, &a.cs)
|
||||
},
|
||||
}
|
||||
_ => {
|
||||
// Currently we handle transposition by explicitly precomputing transposed matrices
|
||||
// and calling the operation again without transposition
|
||||
|
@ -99,7 +101,9 @@ pub fn spmm_csc_prealloc<T>(
|
|||
(NoOp(_), NoOp(_)) => unreachable!(),
|
||||
(Transpose(ref a), NoOp(_)) => (Owned(a.transpose()), Borrowed(b_ref)),
|
||||
(NoOp(_), Transpose(ref b)) => (Borrowed(a_ref), Owned(b.transpose())),
|
||||
(Transpose(ref a), Transpose(ref b)) => (Owned(a.transpose()), Owned(b.transpose()))
|
||||
(Transpose(ref a), Transpose(ref b)) => {
|
||||
(Owned(a.transpose()), Owned(b.transpose()))
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -121,13 +125,20 @@ pub fn spmm_csc_prealloc<T>(
|
|||
/// Panics if `L` is not square, or if `L` and `B` are not dimensionally compatible.
|
||||
pub fn spsolve_csc_lower_triangular<'a, T: RealField>(
|
||||
l: Op<&CscMatrix<T>>,
|
||||
b: impl Into<DMatrixSliceMut<'a, T>>)
|
||||
-> Result<(), OperationError>
|
||||
{
|
||||
b: impl Into<DMatrixSliceMut<'a, T>>,
|
||||
) -> Result<(), OperationError> {
|
||||
let b = b.into();
|
||||
let l_matrix = l.into_inner();
|
||||
assert_eq!(l_matrix.nrows(), l_matrix.ncols(), "Matrix must be square for triangular solve.");
|
||||
assert_eq!(l_matrix.nrows(), b.nrows(), "Dimension mismatch in sparse lower triangular solver.");
|
||||
assert_eq!(
|
||||
l_matrix.nrows(),
|
||||
l_matrix.ncols(),
|
||||
"Matrix must be square for triangular solve."
|
||||
);
|
||||
assert_eq!(
|
||||
l_matrix.nrows(),
|
||||
b.nrows(),
|
||||
"Dimension mismatch in sparse lower triangular solver."
|
||||
);
|
||||
match l {
|
||||
Op::NoOp(a) => spsolve_csc_lower_triangular_no_transpose(a, b),
|
||||
Op::Transpose(a) => spsolve_csc_lower_triangular_transpose(a, b),
|
||||
|
@ -136,16 +147,15 @@ pub fn spsolve_csc_lower_triangular<'a, T: RealField>(
|
|||
|
||||
fn spsolve_csc_lower_triangular_no_transpose<T: RealField>(
|
||||
l: &CscMatrix<T>,
|
||||
b: DMatrixSliceMut<T>)
|
||||
-> Result<(), OperationError>
|
||||
{
|
||||
b: DMatrixSliceMut<T>,
|
||||
) -> Result<(), OperationError> {
|
||||
let mut x = b;
|
||||
|
||||
// Solve column-by-column
|
||||
for j in 0 .. x.ncols() {
|
||||
for j in 0..x.ncols() {
|
||||
let mut x_col_j = x.column_mut(j);
|
||||
|
||||
for k in 0 .. l.ncols() {
|
||||
for k in 0..l.ncols() {
|
||||
let l_col_k = l.col(k);
|
||||
|
||||
// Skip entries above the diagonal
|
||||
|
@ -163,8 +173,8 @@ fn spsolve_csc_lower_triangular_no_transpose<T: RealField>(
|
|||
// Copy value after updating (so we don't run into the borrow checker)
|
||||
let x_kj = x_col_j[k];
|
||||
|
||||
let row_indices = &l_col_k.row_indices()[(diag_csc_index + 1) ..];
|
||||
let l_values = &l_col_k.values()[(diag_csc_index + 1) ..];
|
||||
let row_indices = &l_col_k.row_indices()[(diag_csc_index + 1)..];
|
||||
let l_values = &l_col_k.values()[(diag_csc_index + 1)..];
|
||||
|
||||
// Note: The remaining entries are below the diagonal
|
||||
for (&i, l_ik) in row_indices.iter().zip(l_values) {
|
||||
|
@ -187,24 +197,26 @@ fn spsolve_csc_lower_triangular_no_transpose<T: RealField>(
|
|||
|
||||
fn spsolve_encountered_zero_diagonal() -> Result<(), OperationError> {
|
||||
let message = "Matrix contains at least one diagonal entry that is zero.";
|
||||
Err(OperationError::from_kind_and_message(OperationErrorKind::Singular, String::from(message)))
|
||||
Err(OperationError::from_kind_and_message(
|
||||
OperationErrorKind::Singular,
|
||||
String::from(message),
|
||||
))
|
||||
}
|
||||
|
||||
fn spsolve_csc_lower_triangular_transpose<T: RealField>(
|
||||
l: &CscMatrix<T>,
|
||||
b: DMatrixSliceMut<T>)
|
||||
-> Result<(), OperationError>
|
||||
{
|
||||
b: DMatrixSliceMut<T>,
|
||||
) -> Result<(), OperationError> {
|
||||
let mut x = b;
|
||||
|
||||
// Solve column-by-column
|
||||
for j in 0 .. x.ncols() {
|
||||
for j in 0..x.ncols() {
|
||||
let mut x_col_j = x.column_mut(j);
|
||||
|
||||
// Due to the transposition, we're essentially solving an upper triangular system,
|
||||
// and the columns in our matrix become rows
|
||||
|
||||
for i in (0 .. l.ncols()).rev() {
|
||||
for i in (0..l.ncols()).rev() {
|
||||
let l_col_i = l.col(i);
|
||||
|
||||
// Skip entries above the diagonal
|
||||
|
@ -220,8 +232,8 @@ fn spsolve_csc_lower_triangular_transpose<T: RealField>(
|
|||
// Copy value after updating (so we don't run into the borrow checker)
|
||||
let mut x_ii = x_col_j[i];
|
||||
|
||||
let row_indices = &l_col_i.row_indices()[(diag_csc_index + 1) ..];
|
||||
let a_values = &l_col_i.values()[(diag_csc_index + 1) ..];
|
||||
let row_indices = &l_col_i.row_indices()[(diag_csc_index + 1)..];
|
||||
let a_values = &l_col_i.values()[(diag_csc_index + 1)..];
|
||||
|
||||
// Note: The remaining entries are below the diagonal
|
||||
for (&k, &l_ki) in row_indices.iter().zip(a_values) {
|
||||
|
|
|
@ -1,31 +1,33 @@
|
|||
use crate::csr::CsrMatrix;
|
||||
use crate::ops::{Op};
|
||||
use crate::ops::serial::{OperationError};
|
||||
use nalgebra::{Scalar, DMatrixSlice, ClosedAdd, ClosedMul, DMatrixSliceMut};
|
||||
use num_traits::{Zero, One};
|
||||
use crate::ops::serial::cs::{spadd_cs_prealloc, spmm_cs_dense, spmm_cs_prealloc};
|
||||
use crate::ops::serial::OperationError;
|
||||
use crate::ops::Op;
|
||||
use nalgebra::{ClosedAdd, ClosedMul, DMatrixSlice, DMatrixSliceMut, Scalar};
|
||||
use num_traits::{One, Zero};
|
||||
use std::borrow::Cow;
|
||||
use crate::ops::serial::cs::{spmm_cs_prealloc, spmm_cs_dense, spadd_cs_prealloc};
|
||||
|
||||
/// Sparse-dense matrix-matrix multiplication `C <- beta * C + alpha * op(A) * op(B)`.
|
||||
pub fn spmm_csr_dense<'a, T>(beta: T,
|
||||
c: impl Into<DMatrixSliceMut<'a, T>>,
|
||||
alpha: T,
|
||||
a: Op<&CsrMatrix<T>>,
|
||||
b: Op<impl Into<DMatrixSlice<'a, T>>>)
|
||||
where
|
||||
T: Scalar + ClosedAdd + ClosedMul + Zero + One
|
||||
pub fn spmm_csr_dense<'a, T>(
|
||||
beta: T,
|
||||
c: impl Into<DMatrixSliceMut<'a, T>>,
|
||||
alpha: T,
|
||||
a: Op<&CsrMatrix<T>>,
|
||||
b: Op<impl Into<DMatrixSlice<'a, T>>>,
|
||||
) where
|
||||
T: Scalar + ClosedAdd + ClosedMul + Zero + One,
|
||||
{
|
||||
let b = b.convert();
|
||||
spmm_csr_dense_(beta, c.into(), alpha, a, b)
|
||||
}
|
||||
|
||||
fn spmm_csr_dense_<T>(beta: T,
|
||||
c: DMatrixSliceMut<T>,
|
||||
alpha: T,
|
||||
a: Op<&CsrMatrix<T>>,
|
||||
b: Op<DMatrixSlice<T>>)
|
||||
where
|
||||
T: Scalar + ClosedAdd + ClosedMul + Zero + One
|
||||
fn spmm_csr_dense_<T>(
|
||||
beta: T,
|
||||
c: DMatrixSliceMut<T>,
|
||||
alpha: T,
|
||||
a: Op<&CsrMatrix<T>>,
|
||||
b: Op<DMatrixSlice<T>>,
|
||||
) where
|
||||
T: Scalar + ClosedAdd + ClosedMul + Zero + One,
|
||||
{
|
||||
assert_compatible_spmm_dims!(c, a, b);
|
||||
spmm_cs_dense(beta, c, alpha, a.map_same_op(|a| &a.cs), b)
|
||||
|
@ -41,13 +43,14 @@ where
|
|||
/// # Panics
|
||||
///
|
||||
/// Panics if the dimensions of the matrices involved are not compatible with the expression.
|
||||
pub fn spadd_csr_prealloc<T>(beta: T,
|
||||
c: &mut CsrMatrix<T>,
|
||||
alpha: T,
|
||||
a: Op<&CsrMatrix<T>>)
|
||||
-> Result<(), OperationError>
|
||||
pub fn spadd_csr_prealloc<T>(
|
||||
beta: T,
|
||||
c: &mut CsrMatrix<T>,
|
||||
alpha: T,
|
||||
a: Op<&CsrMatrix<T>>,
|
||||
) -> Result<(), OperationError>
|
||||
where
|
||||
T: Scalar + ClosedAdd + ClosedMul + Zero + One
|
||||
T: Scalar + ClosedAdd + ClosedMul + Zero + One,
|
||||
{
|
||||
assert_compatible_spadd_dims!(c, a);
|
||||
spadd_cs_prealloc(beta, &mut c.cs, alpha, a.map_same_op(|a| &a.cs))
|
||||
|
@ -67,19 +70,17 @@ pub fn spmm_csr_prealloc<T>(
|
|||
c: &mut CsrMatrix<T>,
|
||||
alpha: T,
|
||||
a: Op<&CsrMatrix<T>>,
|
||||
b: Op<&CsrMatrix<T>>)
|
||||
-> Result<(), OperationError>
|
||||
b: Op<&CsrMatrix<T>>,
|
||||
) -> Result<(), OperationError>
|
||||
where
|
||||
T: Scalar + ClosedAdd + ClosedMul + Zero + One
|
||||
T: Scalar + ClosedAdd + ClosedMul + Zero + One,
|
||||
{
|
||||
assert_compatible_spmm_dims!(c, a, b);
|
||||
|
||||
use Op::{NoOp, Transpose};
|
||||
|
||||
match (&a, &b) {
|
||||
(NoOp(ref a), NoOp(ref b)) => {
|
||||
spmm_cs_prealloc(beta, &mut c.cs, alpha, &a.cs, &b.cs)
|
||||
},
|
||||
(NoOp(ref a), NoOp(ref b)) => spmm_cs_prealloc(beta, &mut c.cs, alpha, &a.cs, &b.cs),
|
||||
_ => {
|
||||
// Currently we handle transposition by explicitly precomputing transposed matrices
|
||||
// and calling the operation again without transposition
|
||||
|
@ -93,7 +94,9 @@ where
|
|||
(NoOp(_), NoOp(_)) => unreachable!(),
|
||||
(Transpose(ref a), NoOp(_)) => (Owned(a.transpose()), Borrowed(b_ref)),
|
||||
(NoOp(_), Transpose(ref b)) => (Borrowed(a_ref), Owned(b.transpose())),
|
||||
(Transpose(ref a), Transpose(ref b)) => (Owned(a.transpose()), Owned(b.transpose()))
|
||||
(Transpose(ref a), Transpose(ref b)) => {
|
||||
(Owned(a.transpose()), Owned(b.transpose()))
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -101,4 +104,3 @@ where
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -10,33 +10,31 @@
|
|||
|
||||
#[macro_use]
|
||||
macro_rules! assert_compatible_spmm_dims {
|
||||
($c:expr, $a:expr, $b:expr) => {
|
||||
{
|
||||
use crate::ops::Op::{NoOp, Transpose};
|
||||
match (&$a, &$b) {
|
||||
(NoOp(ref a), NoOp(ref b)) => {
|
||||
assert_eq!($c.nrows(), a.nrows(), "C.nrows() != A.nrows()");
|
||||
assert_eq!($c.ncols(), b.ncols(), "C.ncols() != B.ncols()");
|
||||
assert_eq!(a.ncols(), b.nrows(), "A.ncols() != B.nrows()");
|
||||
},
|
||||
(Transpose(ref a), NoOp(ref b)) => {
|
||||
assert_eq!($c.nrows(), a.ncols(), "C.nrows() != A.ncols()");
|
||||
assert_eq!($c.ncols(), b.ncols(), "C.ncols() != B.ncols()");
|
||||
assert_eq!(a.nrows(), b.nrows(), "A.nrows() != B.nrows()");
|
||||
},
|
||||
(NoOp(ref a), Transpose(ref b)) => {
|
||||
assert_eq!($c.nrows(), a.nrows(), "C.nrows() != A.nrows()");
|
||||
assert_eq!($c.ncols(), b.nrows(), "C.ncols() != B.nrows()");
|
||||
assert_eq!(a.ncols(), b.ncols(), "A.ncols() != B.ncols()");
|
||||
},
|
||||
(Transpose(ref a), Transpose(ref b)) => {
|
||||
assert_eq!($c.nrows(), a.ncols(), "C.nrows() != A.ncols()");
|
||||
assert_eq!($c.ncols(), b.nrows(), "C.ncols() != B.nrows()");
|
||||
assert_eq!(a.nrows(), b.ncols(), "A.nrows() != B.ncols()");
|
||||
}
|
||||
($c:expr, $a:expr, $b:expr) => {{
|
||||
use crate::ops::Op::{NoOp, Transpose};
|
||||
match (&$a, &$b) {
|
||||
(NoOp(ref a), NoOp(ref b)) => {
|
||||
assert_eq!($c.nrows(), a.nrows(), "C.nrows() != A.nrows()");
|
||||
assert_eq!($c.ncols(), b.ncols(), "C.ncols() != B.ncols()");
|
||||
assert_eq!(a.ncols(), b.nrows(), "A.ncols() != B.nrows()");
|
||||
}
|
||||
(Transpose(ref a), NoOp(ref b)) => {
|
||||
assert_eq!($c.nrows(), a.ncols(), "C.nrows() != A.ncols()");
|
||||
assert_eq!($c.ncols(), b.ncols(), "C.ncols() != B.ncols()");
|
||||
assert_eq!(a.nrows(), b.nrows(), "A.nrows() != B.nrows()");
|
||||
}
|
||||
(NoOp(ref a), Transpose(ref b)) => {
|
||||
assert_eq!($c.nrows(), a.nrows(), "C.nrows() != A.nrows()");
|
||||
assert_eq!($c.ncols(), b.nrows(), "C.ncols() != B.nrows()");
|
||||
assert_eq!(a.ncols(), b.ncols(), "A.ncols() != B.ncols()");
|
||||
}
|
||||
(Transpose(ref a), Transpose(ref b)) => {
|
||||
assert_eq!($c.nrows(), a.ncols(), "C.nrows() != A.ncols()");
|
||||
assert_eq!($c.ncols(), b.nrows(), "C.ncols() != B.nrows()");
|
||||
assert_eq!(a.nrows(), b.ncols(), "A.nrows() != B.ncols()");
|
||||
}
|
||||
}
|
||||
}
|
||||
}};
|
||||
}
|
||||
|
||||
#[macro_use]
|
||||
|
@ -47,32 +45,31 @@ macro_rules! assert_compatible_spadd_dims {
|
|||
Op::NoOp(a) => {
|
||||
assert_eq!($c.nrows(), a.nrows(), "C.nrows() != A.nrows()");
|
||||
assert_eq!($c.ncols(), a.ncols(), "C.ncols() != A.ncols()");
|
||||
},
|
||||
}
|
||||
Op::Transpose(a) => {
|
||||
assert_eq!($c.nrows(), a.ncols(), "C.nrows() != A.ncols()");
|
||||
assert_eq!($c.ncols(), a.nrows(), "C.ncols() != A.nrows()");
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
mod cs;
|
||||
mod csc;
|
||||
mod csr;
|
||||
mod pattern;
|
||||
mod cs;
|
||||
|
||||
pub use csc::*;
|
||||
pub use csr::*;
|
||||
pub use pattern::*;
|
||||
use std::fmt::Formatter;
|
||||
use std::fmt;
|
||||
use std::fmt::Formatter;
|
||||
|
||||
/// A description of the error that occurred during an arithmetic operation.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct OperationError {
|
||||
error_kind: OperationErrorKind,
|
||||
message: String
|
||||
message: String,
|
||||
}
|
||||
|
||||
/// The different kinds of operation errors that may occur.
|
||||
|
@ -92,7 +89,10 @@ pub enum OperationErrorKind {
|
|||
|
||||
impl OperationError {
|
||||
fn from_kind_and_message(error_type: OperationErrorKind, message: String) -> Self {
|
||||
Self { error_kind: error_type, message }
|
||||
Self {
|
||||
error_kind: error_type,
|
||||
message,
|
||||
}
|
||||
}
|
||||
|
||||
/// The operation error kind.
|
||||
|
@ -110,8 +110,12 @@ impl fmt::Display for OperationError {
|
|||
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "Sparse matrix operation error: ")?;
|
||||
match self.kind() {
|
||||
OperationErrorKind::InvalidPattern => { write!(f, "InvalidPattern")?; }
|
||||
OperationErrorKind::Singular => { write!(f, "Singular")?; }
|
||||
OperationErrorKind::InvalidPattern => {
|
||||
write!(f, "InvalidPattern")?;
|
||||
}
|
||||
OperationErrorKind::Singular => {
|
||||
write!(f, "Singular")?;
|
||||
}
|
||||
}
|
||||
write!(f, " Message: {}", self.message)
|
||||
}
|
||||
|
|
|
@ -12,11 +12,17 @@ use std::iter;
|
|||
/// # Panics
|
||||
///
|
||||
/// Panics if the patterns do not have the same major and minor dimensions.
|
||||
pub fn spadd_pattern(a: &SparsityPattern,
|
||||
b: &SparsityPattern) -> SparsityPattern
|
||||
{
|
||||
assert_eq!(a.major_dim(), b.major_dim(), "Patterns must have identical major dimensions.");
|
||||
assert_eq!(a.minor_dim(), b.minor_dim(), "Patterns must have identical minor dimensions.");
|
||||
pub fn spadd_pattern(a: &SparsityPattern, b: &SparsityPattern) -> SparsityPattern {
|
||||
assert_eq!(
|
||||
a.major_dim(),
|
||||
b.major_dim(),
|
||||
"Patterns must have identical major dimensions."
|
||||
);
|
||||
assert_eq!(
|
||||
a.minor_dim(),
|
||||
b.minor_dim(),
|
||||
"Patterns must have identical minor dimensions."
|
||||
);
|
||||
|
||||
let mut offsets = Vec::new();
|
||||
let mut indices = Vec::new();
|
||||
|
@ -25,7 +31,7 @@ pub fn spadd_pattern(a: &SparsityPattern,
|
|||
|
||||
offsets.push(0);
|
||||
|
||||
for lane_idx in 0 .. a.major_dim() {
|
||||
for lane_idx in 0..a.major_dim() {
|
||||
let lane_a = a.lane(lane_idx);
|
||||
let lane_b = b.lane(lane_idx);
|
||||
indices.extend(iterate_union(lane_a, lane_b));
|
||||
|
@ -33,8 +39,7 @@ pub fn spadd_pattern(a: &SparsityPattern,
|
|||
}
|
||||
|
||||
// TODO: Consider circumventing format checks? (requires unsafe, should benchmark first)
|
||||
SparsityPattern::try_from_offsets_and_indices(
|
||||
a.major_dim(), a.minor_dim(), offsets, indices)
|
||||
SparsityPattern::try_from_offsets_and_indices(a.major_dim(), a.minor_dim(), offsets, indices)
|
||||
.expect("Internal error: Pattern must be valid by definition")
|
||||
}
|
||||
|
||||
|
@ -66,7 +71,11 @@ pub fn spmm_csc_pattern(a: &SparsityPattern, b: &SparsityPattern) -> SparsityPat
|
|||
/// Panics if the patterns, when interpreted as CSR patterns, are not compatible for
|
||||
/// matrix multiplication.
|
||||
pub fn spmm_csr_pattern(a: &SparsityPattern, b: &SparsityPattern) -> SparsityPattern {
|
||||
assert_eq!(a.minor_dim(), b.major_dim(), "a and b must have compatible dimensions");
|
||||
assert_eq!(
|
||||
a.minor_dim(),
|
||||
b.major_dim(),
|
||||
"a and b must have compatible dimensions"
|
||||
);
|
||||
|
||||
let mut offsets = Vec::new();
|
||||
let mut indices = Vec::new();
|
||||
|
@ -78,7 +87,7 @@ pub fn spmm_csr_pattern(a: &SparsityPattern, b: &SparsityPattern) -> SparsityPat
|
|||
// (would cut memory use to 1/8, which might help reduce cache misses)
|
||||
let mut visited = vec![false; b.minor_dim()];
|
||||
|
||||
for i in 0 .. a.major_dim() {
|
||||
for i in 0..a.major_dim() {
|
||||
let a_lane_i = a.lane(i);
|
||||
let c_lane_i_offset = *offsets.last().unwrap();
|
||||
for &k in a_lane_i {
|
||||
|
@ -93,7 +102,7 @@ pub fn spmm_csr_pattern(a: &SparsityPattern, b: &SparsityPattern) -> SparsityPat
|
|||
}
|
||||
}
|
||||
|
||||
let c_lane_i = &mut indices[c_lane_i_offset ..];
|
||||
let c_lane_i = &mut indices[c_lane_i_offset..];
|
||||
c_lane_i.sort_unstable();
|
||||
|
||||
// Reset visits so that visited[j] == false for all j for the next major lane
|
||||
|
@ -110,21 +119,23 @@ pub fn spmm_csr_pattern(a: &SparsityPattern, b: &SparsityPattern) -> SparsityPat
|
|||
|
||||
/// Iterate over the union of the two sets represented by sorted slices
|
||||
/// (with unique elements)
|
||||
fn iterate_union<'a>(mut sorted_a: &'a [usize],
|
||||
mut sorted_b: &'a [usize]) -> impl Iterator<Item=usize> + 'a {
|
||||
fn iterate_union<'a>(
|
||||
mut sorted_a: &'a [usize],
|
||||
mut sorted_b: &'a [usize],
|
||||
) -> impl Iterator<Item = usize> + 'a {
|
||||
iter::from_fn(move || {
|
||||
if let (Some(a_item), Some(b_item)) = (sorted_a.first(), sorted_b.first()) {
|
||||
let item = if a_item < b_item {
|
||||
sorted_a = &sorted_a[1 ..];
|
||||
sorted_a = &sorted_a[1..];
|
||||
a_item
|
||||
} else if b_item < a_item {
|
||||
sorted_b = &sorted_b[1 ..];
|
||||
sorted_b = &sorted_b[1..];
|
||||
b_item
|
||||
} else {
|
||||
// Both lists contain the same element, advance both slices to avoid
|
||||
// duplicate entries in the result
|
||||
sorted_a = &sorted_a[1 ..];
|
||||
sorted_b = &sorted_b[1 ..];
|
||||
sorted_a = &sorted_a[1..];
|
||||
sorted_b = &sorted_b[1..];
|
||||
a_item
|
||||
};
|
||||
Some(*item)
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
//! Sparsity patterns for CSR and CSC matrices.
|
||||
use crate::SparseFormatError;
|
||||
use std::fmt;
|
||||
use std::error::Error;
|
||||
use crate::cs::transpose_cs;
|
||||
use crate::SparseFormatError;
|
||||
use std::error::Error;
|
||||
use std::fmt;
|
||||
|
||||
/// A representation of the sparsity pattern of a CSR or CSC matrix.
|
||||
///
|
||||
|
@ -137,7 +137,7 @@ impl SparsityPattern {
|
|||
// minor indices within a lane are sorted, unique. In addition, each minor index
|
||||
// must be in bounds with respect to the minor dimension.
|
||||
{
|
||||
for lane_idx in 0 .. major_dim {
|
||||
for lane_idx in 0..major_dim {
|
||||
let range_start = major_offsets[lane_idx];
|
||||
let range_end = major_offsets[lane_idx + 1];
|
||||
|
||||
|
@ -146,7 +146,7 @@ impl SparsityPattern {
|
|||
return Err(NonmonotonicOffsets);
|
||||
}
|
||||
|
||||
let minor_indices = &minor_indices[range_start .. range_end];
|
||||
let minor_indices = &minor_indices[range_start..range_end];
|
||||
|
||||
// We test for in-bounds, uniqueness and monotonicity at the same time
|
||||
// to ensure that we only visit each minor index once
|
||||
|
@ -232,17 +232,20 @@ impl SparsityPattern {
|
|||
// By using unit () values, we can use the same routines as for CSR/CSC matrices
|
||||
let values = vec![(); self.nnz()];
|
||||
let (new_offsets, new_indices, _) = transpose_cs(
|
||||
self.major_dim(),
|
||||
self.minor_dim(),
|
||||
self.major_offsets(),
|
||||
self.minor_indices(),
|
||||
&values);
|
||||
self.major_dim(),
|
||||
self.minor_dim(),
|
||||
self.major_offsets(),
|
||||
self.minor_indices(),
|
||||
&values,
|
||||
);
|
||||
// TODO: Skip checks
|
||||
Self::try_from_offsets_and_indices(self.minor_dim(),
|
||||
self.major_dim(),
|
||||
new_offsets,
|
||||
new_indices)
|
||||
.expect("Internal error: Transpose should never fail.")
|
||||
Self::try_from_offsets_and_indices(
|
||||
self.minor_dim(),
|
||||
self.major_dim(),
|
||||
new_offsets,
|
||||
new_indices,
|
||||
)
|
||||
.expect("Internal error: Transpose should never fail.")
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -275,22 +278,25 @@ pub enum SparsityPatternFormatError {
|
|||
|
||||
impl From<SparsityPatternFormatError> for SparseFormatError {
|
||||
fn from(err: SparsityPatternFormatError) -> Self {
|
||||
use SparsityPatternFormatError::*;
|
||||
use SparsityPatternFormatError::DuplicateEntry as PatternDuplicateEntry;
|
||||
use crate::SparseFormatErrorKind;
|
||||
use crate::SparseFormatErrorKind::*;
|
||||
use SparsityPatternFormatError::DuplicateEntry as PatternDuplicateEntry;
|
||||
use SparsityPatternFormatError::*;
|
||||
match err {
|
||||
InvalidOffsetArrayLength
|
||||
| InvalidOffsetFirstLast
|
||||
| NonmonotonicOffsets
|
||||
| NonmonotonicMinorIndices
|
||||
=> SparseFormatError::from_kind_and_error(InvalidStructure, Box::from(err)),
|
||||
MinorIndexOutOfBounds
|
||||
=> SparseFormatError::from_kind_and_error(IndexOutOfBounds,
|
||||
Box::from(err)),
|
||||
PatternDuplicateEntry
|
||||
=> SparseFormatError::from_kind_and_error(SparseFormatErrorKind::DuplicateEntry,
|
||||
Box::from(err)),
|
||||
| NonmonotonicMinorIndices => {
|
||||
SparseFormatError::from_kind_and_error(InvalidStructure, Box::from(err))
|
||||
}
|
||||
MinorIndexOutOfBounds => {
|
||||
SparseFormatError::from_kind_and_error(IndexOutOfBounds, Box::from(err))
|
||||
}
|
||||
PatternDuplicateEntry => SparseFormatError::from_kind_and_error(
|
||||
#[allow(unused_qualifications)]
|
||||
SparseFormatErrorKind::DuplicateEntry,
|
||||
Box::from(err),
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -300,22 +306,25 @@ impl fmt::Display for SparsityPatternFormatError {
|
|||
match self {
|
||||
SparsityPatternFormatError::InvalidOffsetArrayLength => {
|
||||
write!(f, "Length of offset array is not equal to (major_dim + 1).")
|
||||
},
|
||||
}
|
||||
SparsityPatternFormatError::InvalidOffsetFirstLast => {
|
||||
write!(f, "First or last offset is incompatible with format.")
|
||||
},
|
||||
}
|
||||
SparsityPatternFormatError::NonmonotonicOffsets => {
|
||||
write!(f, "Offsets are not monotonically increasing.")
|
||||
},
|
||||
}
|
||||
SparsityPatternFormatError::MinorIndexOutOfBounds => {
|
||||
write!(f, "A minor index is out of bounds.")
|
||||
},
|
||||
}
|
||||
SparsityPatternFormatError::DuplicateEntry => {
|
||||
write!(f, "Input data contains duplicate entries.")
|
||||
},
|
||||
}
|
||||
SparsityPatternFormatError::NonmonotonicMinorIndices => {
|
||||
write!(f, "Minor indices are not monotonically increasing within each lane.")
|
||||
},
|
||||
write!(
|
||||
f,
|
||||
"Minor indices are not monotonically increasing within each lane."
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -335,12 +344,12 @@ pub struct SparsityPatternIter<'a> {
|
|||
impl<'a> SparsityPatternIter<'a> {
|
||||
fn from_pattern(pattern: &'a SparsityPattern) -> Self {
|
||||
let first_lane_end = pattern.major_offsets().get(1).unwrap_or(&0);
|
||||
let minors_in_first_lane = &pattern.minor_indices()[0 .. *first_lane_end];
|
||||
let minors_in_first_lane = &pattern.minor_indices()[0..*first_lane_end];
|
||||
Self {
|
||||
major_offsets: pattern.major_offsets(),
|
||||
minor_indices: pattern.minor_indices(),
|
||||
current_lane_idx: 0,
|
||||
remaining_minors_in_lane: minors_in_first_lane
|
||||
remaining_minors_in_lane: minors_in_first_lane,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -374,12 +383,11 @@ impl<'a> Iterator for SparsityPatternIter<'a> {
|
|||
let lower = self.major_offsets[self.current_lane_idx];
|
||||
let upper = self.major_offsets[self.current_lane_idx + 1];
|
||||
if upper > lower {
|
||||
self.remaining_minors_in_lane = &self.minor_indices[(lower + 1) .. upper];
|
||||
return Some((self.current_lane_idx, self.minor_indices[lower]))
|
||||
self.remaining_minors_in_lane = &self.minor_indices[(lower + 1)..upper];
|
||||
return Some((self.current_lane_idx, self.minor_indices[lower]));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -11,20 +11,22 @@
|
|||
mod proptest_patched;
|
||||
|
||||
use crate::coo::CooMatrix;
|
||||
use proptest::prelude::*;
|
||||
use proptest::collection::{vec, hash_map, btree_set};
|
||||
use nalgebra::{Scalar, Dim};
|
||||
use std::cmp::min;
|
||||
use std::iter::{repeat};
|
||||
use proptest::sample::{Index};
|
||||
use crate::csc::CscMatrix;
|
||||
use crate::csr::CsrMatrix;
|
||||
use crate::pattern::SparsityPattern;
|
||||
use crate::csc::CscMatrix;
|
||||
use nalgebra::proptest::DimRange;
|
||||
use nalgebra::{Dim, Scalar};
|
||||
use proptest::collection::{btree_set, hash_map, vec};
|
||||
use proptest::prelude::*;
|
||||
use proptest::sample::Index;
|
||||
use std::cmp::min;
|
||||
use std::iter::repeat;
|
||||
|
||||
fn dense_row_major_coord_strategy(nrows: usize, ncols: usize, nnz: usize)
|
||||
-> impl Strategy<Value=Vec<(usize, usize)>>
|
||||
{
|
||||
fn dense_row_major_coord_strategy(
|
||||
nrows: usize,
|
||||
ncols: usize,
|
||||
nnz: usize,
|
||||
) -> impl Strategy<Value = Vec<(usize, usize)>> {
|
||||
assert!(nnz <= nrows * ncols);
|
||||
let mut booleans = vec![true; nnz];
|
||||
booleans.append(&mut vec![false; (nrows * ncols) - nnz]);
|
||||
|
@ -38,33 +40,33 @@ fn dense_row_major_coord_strategy(nrows: usize, ncols: usize, nnz: usize)
|
|||
// // Need to shuffle to make sure they are randomly distributed
|
||||
// .prop_shuffle()
|
||||
|
||||
proptest_patched::Shuffle(Just(booleans))
|
||||
.prop_map(move |booleans| {
|
||||
booleans
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.filter_map(|(index, is_entry)| {
|
||||
if is_entry {
|
||||
// Convert linear index to row/col pair
|
||||
let i = index / ncols;
|
||||
let j = index % ncols;
|
||||
Some((i, j))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
proptest_patched::Shuffle(Just(booleans)).prop_map(move |booleans| {
|
||||
booleans
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.filter_map(|(index, is_entry)| {
|
||||
if is_entry {
|
||||
// Convert linear index to row/col pair
|
||||
let i = index / ncols;
|
||||
let j = index % ncols;
|
||||
Some((i, j))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
}
|
||||
|
||||
/// A strategy for generating `nnz` triplets.
|
||||
///
|
||||
/// This strategy should generally only be used when `nnz` is close to `nrows * ncols`.
|
||||
fn dense_triplet_strategy<T>(value_strategy: T,
|
||||
nrows: usize,
|
||||
ncols: usize,
|
||||
nnz: usize)
|
||||
-> impl Strategy<Value=Vec<(usize, usize, T::Value)>>
|
||||
fn dense_triplet_strategy<T>(
|
||||
value_strategy: T,
|
||||
nrows: usize,
|
||||
ncols: usize,
|
||||
nnz: usize,
|
||||
) -> impl Strategy<Value = Vec<(usize, usize, T::Value)>>
|
||||
where
|
||||
T: Strategy + Clone + 'static,
|
||||
T::Value: Scalar,
|
||||
|
@ -100,15 +102,14 @@ where
|
|||
})
|
||||
// Assign values to each coordinate pair in order to generate a list of triplets
|
||||
.prop_flat_map(move |coords| {
|
||||
vec![value_strategy.clone(); coords.len()]
|
||||
.prop_map(move |values| {
|
||||
coords.clone().into_iter()
|
||||
.zip(values)
|
||||
.map(|((i, j), v)| {
|
||||
(i, j, v)
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
vec![value_strategy.clone(); coords.len()].prop_map(move |values| {
|
||||
coords
|
||||
.clone()
|
||||
.into_iter()
|
||||
.zip(values)
|
||||
.map(|((i, j), v)| (i, j, v))
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -116,25 +117,23 @@ where
|
|||
///
|
||||
/// This strategy should generally only be used when `nnz << nrows * ncols`. If `nnz` is too
|
||||
/// close to `nrows * ncols` it may fail due to excessive rejected samples.
|
||||
fn sparse_triplet_strategy<T>(value_strategy: T,
|
||||
nrows: usize,
|
||||
ncols: usize,
|
||||
nnz: usize)
|
||||
-> impl Strategy<Value=Vec<(usize, usize, T::Value)>>
|
||||
where
|
||||
T: Strategy + Clone + 'static,
|
||||
T::Value: Scalar,
|
||||
fn sparse_triplet_strategy<T>(
|
||||
value_strategy: T,
|
||||
nrows: usize,
|
||||
ncols: usize,
|
||||
nnz: usize,
|
||||
) -> impl Strategy<Value = Vec<(usize, usize, T::Value)>>
|
||||
where
|
||||
T: Strategy + Clone + 'static,
|
||||
T::Value: Scalar,
|
||||
{
|
||||
// Have to handle the zero case: proptest doesn't like empty ranges (i.e. 0 .. 0)
|
||||
let row_index_strategy = if nrows > 0 { 0 .. nrows } else { 0 .. 1 };
|
||||
let col_index_strategy = if ncols > 0 { 0 .. ncols } else { 0 .. 1 };
|
||||
let row_index_strategy = if nrows > 0 { 0..nrows } else { 0..1 };
|
||||
let col_index_strategy = if ncols > 0 { 0..ncols } else { 0..1 };
|
||||
let coord_strategy = (row_index_strategy, col_index_strategy);
|
||||
hash_map(coord_strategy, value_strategy.clone(), nnz)
|
||||
.prop_map(|hash_map| {
|
||||
let triplets: Vec<_> = hash_map
|
||||
.into_iter()
|
||||
.map(|((i, j), v)| (i, j, v))
|
||||
.collect();
|
||||
let triplets: Vec<_> = hash_map.into_iter().map(|((i, j), v)| (i, j, v)).collect();
|
||||
triplets
|
||||
})
|
||||
// Although order in the hash map is unspecified, it's not necessarily *random*
|
||||
|
@ -153,36 +152,41 @@ pub fn coo_no_duplicates<T>(
|
|||
value_strategy: T,
|
||||
rows: impl Into<DimRange>,
|
||||
cols: impl Into<DimRange>,
|
||||
max_nonzeros: usize) -> impl Strategy<Value=CooMatrix<T::Value>>
|
||||
max_nonzeros: usize,
|
||||
) -> impl Strategy<Value = CooMatrix<T::Value>>
|
||||
where
|
||||
T: Strategy + Clone + 'static,
|
||||
T::Value: Scalar,
|
||||
{
|
||||
(rows.into().to_range_inclusive(), cols.into().to_range_inclusive())
|
||||
(
|
||||
rows.into().to_range_inclusive(),
|
||||
cols.into().to_range_inclusive(),
|
||||
)
|
||||
.prop_flat_map(move |(nrows, ncols)| {
|
||||
let max_nonzeros = min(max_nonzeros, nrows * ncols);
|
||||
let size_range = 0 ..= max_nonzeros;
|
||||
let size_range = 0..=max_nonzeros;
|
||||
let value_strategy = value_strategy.clone();
|
||||
|
||||
size_range.prop_flat_map(move |nnz| {
|
||||
let value_strategy = value_strategy.clone();
|
||||
if nnz as f64 > 0.10 * (nrows as f64) * (ncols as f64) {
|
||||
// If the number of nnz is sufficiently dense, then use the dense
|
||||
// sample strategy
|
||||
dense_triplet_strategy(value_strategy, nrows, ncols, nnz).boxed()
|
||||
} else {
|
||||
// Otherwise, use a hash map strategy so that we can get a sparse sampling
|
||||
// (so that complexity is rather on the order of max_nnz than nrows * ncols)
|
||||
sparse_triplet_strategy(value_strategy, nrows, ncols, nnz).boxed()
|
||||
}
|
||||
})
|
||||
.prop_map(move |triplets| {
|
||||
let mut coo = CooMatrix::new(nrows, ncols);
|
||||
for (i, j, v) in triplets {
|
||||
coo.push(i, j, v);
|
||||
}
|
||||
coo
|
||||
})
|
||||
size_range
|
||||
.prop_flat_map(move |nnz| {
|
||||
let value_strategy = value_strategy.clone();
|
||||
if nnz as f64 > 0.10 * (nrows as f64) * (ncols as f64) {
|
||||
// If the number of nnz is sufficiently dense, then use the dense
|
||||
// sample strategy
|
||||
dense_triplet_strategy(value_strategy, nrows, ncols, nnz).boxed()
|
||||
} else {
|
||||
// Otherwise, use a hash map strategy so that we can get a sparse sampling
|
||||
// (so that complexity is rather on the order of max_nnz than nrows * ncols)
|
||||
sparse_triplet_strategy(value_strategy, nrows, ncols, nnz).boxed()
|
||||
}
|
||||
})
|
||||
.prop_map(move |triplets| {
|
||||
let mut coo = CooMatrix::new(nrows, ncols);
|
||||
for (i, j, v) in triplets {
|
||||
coo.push(i, j, v);
|
||||
}
|
||||
coo
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -198,21 +202,22 @@ where
|
|||
/// number of duplicate entries is determined by `max_duplicates`. Note that the matrix might still
|
||||
/// contain explicitly stored zeroes if the value strategy is capable of generating zero values.
|
||||
pub fn coo_with_duplicates<T>(
|
||||
value_strategy: T,
|
||||
rows: impl Into<DimRange>,
|
||||
cols: impl Into<DimRange>,
|
||||
max_nonzeros: usize,
|
||||
max_duplicates: usize)
|
||||
-> impl Strategy<Value=CooMatrix<T::Value>>
|
||||
value_strategy: T,
|
||||
rows: impl Into<DimRange>,
|
||||
cols: impl Into<DimRange>,
|
||||
max_nonzeros: usize,
|
||||
max_duplicates: usize,
|
||||
) -> impl Strategy<Value = CooMatrix<T::Value>>
|
||||
where
|
||||
T: Strategy + Clone + 'static,
|
||||
T::Value: Scalar,
|
||||
{
|
||||
let coo_strategy = coo_no_duplicates(value_strategy.clone(), rows, cols, max_nonzeros);
|
||||
let duplicate_strategy = vec((any::<Index>(), value_strategy.clone()), 0 ..= max_duplicates);
|
||||
let duplicate_strategy = vec((any::<Index>(), value_strategy.clone()), 0..=max_duplicates);
|
||||
(coo_strategy, duplicate_strategy)
|
||||
.prop_flat_map(|(coo, duplicates)| {
|
||||
let mut triplets: Vec<(usize, usize, T::Value)> = coo.triplet_iter()
|
||||
let mut triplets: Vec<(usize, usize, T::Value)> = coo
|
||||
.triplet_iter()
|
||||
.map(|(i, j, v)| (i, j, v.clone()))
|
||||
.collect();
|
||||
if !triplets.is_empty() {
|
||||
|
@ -238,9 +243,13 @@ where
|
|||
})
|
||||
}
|
||||
|
||||
fn sparsity_pattern_from_row_major_coords<I>(nmajor: usize, nminor: usize, coords: I) -> SparsityPattern
|
||||
fn sparsity_pattern_from_row_major_coords<I>(
|
||||
nmajor: usize,
|
||||
nminor: usize,
|
||||
coords: I,
|
||||
) -> SparsityPattern
|
||||
where
|
||||
I: Iterator<Item=(usize, usize)> + ExactSizeIterator,
|
||||
I: Iterator<Item = (usize, usize)> + ExactSizeIterator,
|
||||
{
|
||||
let mut minors = Vec::with_capacity(coords.len());
|
||||
let mut offsets = Vec::with_capacity(nmajor + 1);
|
||||
|
@ -248,8 +257,11 @@ where
|
|||
offsets.push(0);
|
||||
for (idx, (i, j)) in coords.enumerate() {
|
||||
assert!(i >= current_major);
|
||||
assert!(i < nmajor && j < nminor, "Generated coords are out of bounds");
|
||||
while current_major < i{
|
||||
assert!(
|
||||
i < nmajor && j < nminor,
|
||||
"Generated coords are out of bounds"
|
||||
);
|
||||
while current_major < i {
|
||||
offsets.push(idx);
|
||||
current_major += 1;
|
||||
}
|
||||
|
@ -264,10 +276,7 @@ where
|
|||
assert_eq!(offsets.first().unwrap(), &0);
|
||||
assert_eq!(offsets.len(), nmajor + 1);
|
||||
|
||||
SparsityPattern::try_from_offsets_and_indices(nmajor,
|
||||
nminor,
|
||||
offsets,
|
||||
minors)
|
||||
SparsityPattern::try_from_offsets_and_indices(nmajor, nminor, offsets, minors)
|
||||
.expect("Internal error: Generated sparsity pattern is invalid")
|
||||
}
|
||||
|
||||
|
@ -275,14 +284,17 @@ where
|
|||
pub fn sparsity_pattern(
|
||||
major_lanes: impl Into<DimRange>,
|
||||
minor_lanes: impl Into<DimRange>,
|
||||
max_nonzeros: usize)
|
||||
-> impl Strategy<Value=SparsityPattern>
|
||||
{
|
||||
(major_lanes.into().to_range_inclusive(), minor_lanes.into().to_range_inclusive())
|
||||
max_nonzeros: usize,
|
||||
) -> impl Strategy<Value = SparsityPattern> {
|
||||
(
|
||||
major_lanes.into().to_range_inclusive(),
|
||||
minor_lanes.into().to_range_inclusive(),
|
||||
)
|
||||
.prop_flat_map(move |(nmajor, nminor)| {
|
||||
let max_nonzeros = min(nmajor * nminor, max_nonzeros);
|
||||
(Just(nmajor), Just(nminor), 0 ..= max_nonzeros)
|
||||
}).prop_flat_map(move |(nmajor, nminor, nnz)| {
|
||||
(Just(nmajor), Just(nminor), 0..=max_nonzeros)
|
||||
})
|
||||
.prop_flat_map(move |(nmajor, nminor, nnz)| {
|
||||
if 10 * nnz < nmajor * nminor {
|
||||
// If nnz is small compared to a dense matrix, then use a sparse sampling strategy
|
||||
btree_set((0..nmajor, 0..nminor), nnz)
|
||||
|
@ -297,55 +309,66 @@ pub fn sparsity_pattern(
|
|||
.prop_map(move |coords| {
|
||||
let coords = coords.into_iter();
|
||||
sparsity_pattern_from_row_major_coords(nmajor, nminor, coords)
|
||||
}).boxed()
|
||||
})
|
||||
.boxed()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// A strategy for generating CSR matrices.
|
||||
pub fn csr<T>(value_strategy: T,
|
||||
rows: impl Into<DimRange>,
|
||||
cols: impl Into<DimRange>,
|
||||
max_nonzeros: usize)
|
||||
-> impl Strategy<Value=CsrMatrix<T::Value>>
|
||||
pub fn csr<T>(
|
||||
value_strategy: T,
|
||||
rows: impl Into<DimRange>,
|
||||
cols: impl Into<DimRange>,
|
||||
max_nonzeros: usize,
|
||||
) -> impl Strategy<Value = CsrMatrix<T::Value>>
|
||||
where
|
||||
T: Strategy + Clone + 'static,
|
||||
T::Value: Scalar,
|
||||
{
|
||||
let rows = rows.into();
|
||||
let cols = cols.into();
|
||||
sparsity_pattern(rows.lower_bound().value() ..= rows.upper_bound().value(), cols.lower_bound().value() ..= cols.upper_bound().value(), max_nonzeros)
|
||||
.prop_flat_map(move |pattern| {
|
||||
let nnz = pattern.nnz();
|
||||
let values = vec![value_strategy.clone(); nnz];
|
||||
(Just(pattern), values)
|
||||
})
|
||||
.prop_map(|(pattern, values)| {
|
||||
CsrMatrix::try_from_pattern_and_values(pattern, values)
|
||||
.expect("Internal error: Generated CsrMatrix is invalid")
|
||||
})
|
||||
sparsity_pattern(
|
||||
rows.lower_bound().value()..=rows.upper_bound().value(),
|
||||
cols.lower_bound().value()..=cols.upper_bound().value(),
|
||||
max_nonzeros,
|
||||
)
|
||||
.prop_flat_map(move |pattern| {
|
||||
let nnz = pattern.nnz();
|
||||
let values = vec![value_strategy.clone(); nnz];
|
||||
(Just(pattern), values)
|
||||
})
|
||||
.prop_map(|(pattern, values)| {
|
||||
CsrMatrix::try_from_pattern_and_values(pattern, values)
|
||||
.expect("Internal error: Generated CsrMatrix is invalid")
|
||||
})
|
||||
}
|
||||
|
||||
/// A strategy for generating CSC matrices.
|
||||
pub fn csc<T>(value_strategy: T,
|
||||
rows: impl Into<DimRange>,
|
||||
cols: impl Into<DimRange>,
|
||||
max_nonzeros: usize)
|
||||
-> impl Strategy<Value=CscMatrix<T::Value>>
|
||||
where
|
||||
T: Strategy + Clone + 'static,
|
||||
T::Value: Scalar,
|
||||
pub fn csc<T>(
|
||||
value_strategy: T,
|
||||
rows: impl Into<DimRange>,
|
||||
cols: impl Into<DimRange>,
|
||||
max_nonzeros: usize,
|
||||
) -> impl Strategy<Value = CscMatrix<T::Value>>
|
||||
where
|
||||
T: Strategy + Clone + 'static,
|
||||
T::Value: Scalar,
|
||||
{
|
||||
let rows = rows.into();
|
||||
let cols = cols.into();
|
||||
sparsity_pattern(cols.lower_bound().value() ..= cols.upper_bound().value(), rows.lower_bound().value() ..= rows.upper_bound().value(), max_nonzeros)
|
||||
.prop_flat_map(move |pattern| {
|
||||
let nnz = pattern.nnz();
|
||||
let values = vec![value_strategy.clone(); nnz];
|
||||
(Just(pattern), values)
|
||||
})
|
||||
.prop_map(|(pattern, values)| {
|
||||
CscMatrix::try_from_pattern_and_values(pattern, values)
|
||||
.expect("Internal error: Generated CscMatrix is invalid")
|
||||
})
|
||||
sparsity_pattern(
|
||||
cols.lower_bound().value()..=cols.upper_bound().value(),
|
||||
rows.lower_bound().value()..=rows.upper_bound().value(),
|
||||
max_nonzeros,
|
||||
)
|
||||
.prop_flat_map(move |pattern| {
|
||||
let nnz = pattern.nnz();
|
||||
let values = vec![value_strategy.clone(); nnz];
|
||||
(Just(pattern), values)
|
||||
})
|
||||
.prop_map(|(pattern, values)| {
|
||||
CscMatrix::try_from_pattern_and_values(pattern, values)
|
||||
.expect("Internal error: Generated CscMatrix is invalid")
|
||||
})
|
||||
}
|
|
@ -22,19 +22,19 @@
|
|||
|
||||
*/
|
||||
|
||||
use proptest::strategy::{Strategy, Shuffleable, NewTree, ValueTree};
|
||||
use proptest::test_runner::{TestRunner, TestRng};
|
||||
use std::cell::Cell;
|
||||
use proptest::num;
|
||||
use proptest::prelude::Rng;
|
||||
use proptest::strategy::{NewTree, Shuffleable, Strategy, ValueTree};
|
||||
use proptest::test_runner::{TestRng, TestRunner};
|
||||
use std::cell::Cell;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
#[must_use = "strategies do nothing unless used"]
|
||||
pub struct Shuffle<S>(pub(super) S);
|
||||
|
||||
impl<S: Strategy> Strategy for Shuffle<S>
|
||||
where
|
||||
S::Value: Shuffleable,
|
||||
where
|
||||
S::Value: Shuffleable,
|
||||
{
|
||||
type Tree = ShuffleValueTree<S::Tree>;
|
||||
type Value = S::Value;
|
||||
|
@ -60,8 +60,8 @@ pub struct ShuffleValueTree<V> {
|
|||
}
|
||||
|
||||
impl<V: ValueTree> ShuffleValueTree<V>
|
||||
where
|
||||
V::Value: Shuffleable,
|
||||
where
|
||||
V::Value: Shuffleable,
|
||||
{
|
||||
fn init_dist(&self, dflt: usize) -> usize {
|
||||
if self.dist.get().is_none() {
|
||||
|
@ -79,8 +79,8 @@ impl<V: ValueTree> ShuffleValueTree<V>
|
|||
}
|
||||
|
||||
impl<V: ValueTree> ValueTree for ShuffleValueTree<V>
|
||||
where
|
||||
V::Value: Shuffleable,
|
||||
where
|
||||
V::Value: Shuffleable,
|
||||
{
|
||||
type Value = V::Value;
|
||||
|
||||
|
|
|
@ -1,15 +1,15 @@
|
|||
use proptest::strategy::Strategy;
|
||||
use nalgebra_sparse::csr::CsrMatrix;
|
||||
use nalgebra_sparse::proptest::{csr, csc};
|
||||
use nalgebra_sparse::csc::CscMatrix;
|
||||
use std::ops::RangeInclusive;
|
||||
use std::convert::{TryFrom};
|
||||
use nalgebra_sparse::csr::CsrMatrix;
|
||||
use nalgebra_sparse::proptest::{csc, csr};
|
||||
use proptest::strategy::Strategy;
|
||||
use std::convert::TryFrom;
|
||||
use std::fmt::Debug;
|
||||
use std::ops::RangeInclusive;
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! assert_panics {
|
||||
($e:expr) => {{
|
||||
use std::panic::{catch_unwind};
|
||||
use std::panic::catch_unwind;
|
||||
use std::stringify;
|
||||
let expr_string = stringify!($e);
|
||||
|
||||
|
@ -22,37 +22,56 @@ macro_rules! assert_panics {
|
|||
|
||||
let result = catch_unwind(|| $e);
|
||||
if result.is_ok() {
|
||||
panic!("assert_panics!({}) failed: the expression did not panic.", expr_string);
|
||||
panic!(
|
||||
"assert_panics!({}) failed: the expression did not panic.",
|
||||
expr_string
|
||||
);
|
||||
}
|
||||
}};
|
||||
}
|
||||
|
||||
pub const PROPTEST_MATRIX_DIM: RangeInclusive<usize> = 0..=6;
|
||||
pub const PROPTEST_MAX_NNZ: usize = 40;
|
||||
pub const PROPTEST_I32_VALUE_STRATEGY: RangeInclusive<i32> = -5 ..= 5;
|
||||
pub const PROPTEST_I32_VALUE_STRATEGY: RangeInclusive<i32> = -5..=5;
|
||||
|
||||
pub fn value_strategy<T>() -> RangeInclusive<T>
|
||||
where
|
||||
T: TryFrom<i32>,
|
||||
T::Error: Debug
|
||||
T::Error: Debug,
|
||||
{
|
||||
let (start, end) = (PROPTEST_I32_VALUE_STRATEGY.start(), PROPTEST_I32_VALUE_STRATEGY.end());
|
||||
T::try_from(*start).unwrap() ..= T::try_from(*end).unwrap()
|
||||
let (start, end) = (
|
||||
PROPTEST_I32_VALUE_STRATEGY.start(),
|
||||
PROPTEST_I32_VALUE_STRATEGY.end(),
|
||||
);
|
||||
T::try_from(*start).unwrap()..=T::try_from(*end).unwrap()
|
||||
}
|
||||
|
||||
pub fn non_zero_i32_value_strategy() -> impl Strategy<Value=i32> {
|
||||
let (start, end) = (PROPTEST_I32_VALUE_STRATEGY.start(), PROPTEST_I32_VALUE_STRATEGY.end());
|
||||
pub fn non_zero_i32_value_strategy() -> impl Strategy<Value = i32> {
|
||||
let (start, end) = (
|
||||
PROPTEST_I32_VALUE_STRATEGY.start(),
|
||||
PROPTEST_I32_VALUE_STRATEGY.end(),
|
||||
);
|
||||
assert!(start < &0);
|
||||
assert!(end > &0);
|
||||
// Note: we don't use RangeInclusive for the second range, because then we'd have different
|
||||
// types, which would require boxing
|
||||
(*start .. 0).prop_union(1 .. *end + 1)
|
||||
(*start..0).prop_union(1..*end + 1)
|
||||
}
|
||||
|
||||
pub fn csr_strategy() -> impl Strategy<Value=CsrMatrix<i32>> {
|
||||
csr(PROPTEST_I32_VALUE_STRATEGY, PROPTEST_MATRIX_DIM, PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ)
|
||||
pub fn csr_strategy() -> impl Strategy<Value = CsrMatrix<i32>> {
|
||||
csr(
|
||||
PROPTEST_I32_VALUE_STRATEGY,
|
||||
PROPTEST_MATRIX_DIM,
|
||||
PROPTEST_MATRIX_DIM,
|
||||
PROPTEST_MAX_NNZ,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn csc_strategy() -> impl Strategy<Value=CscMatrix<i32>> {
|
||||
csc(PROPTEST_I32_VALUE_STRATEGY, PROPTEST_MATRIX_DIM, PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ)
|
||||
pub fn csc_strategy() -> impl Strategy<Value = CscMatrix<i32>> {
|
||||
csc(
|
||||
PROPTEST_I32_VALUE_STRATEGY,
|
||||
PROPTEST_MATRIX_DIM,
|
||||
PROPTEST_MATRIX_DIM,
|
||||
PROPTEST_MAX_NNZ,
|
||||
)
|
||||
}
|
||||
|
|
|
@ -1,17 +1,16 @@
|
|||
use nalgebra_sparse::coo::CooMatrix;
|
||||
use nalgebra_sparse::convert::serial::{convert_coo_dense, convert_coo_csr,
|
||||
convert_dense_coo, convert_csr_dense,
|
||||
convert_csr_coo, convert_dense_csr,
|
||||
convert_csc_coo, convert_coo_csc,
|
||||
convert_csc_dense, convert_dense_csc,
|
||||
convert_csr_csc, convert_csc_csr};
|
||||
use nalgebra_sparse::proptest::{coo_with_duplicates, coo_no_duplicates, csr, csc};
|
||||
use nalgebra::proptest::matrix;
|
||||
use proptest::prelude::*;
|
||||
use nalgebra::DMatrix;
|
||||
use nalgebra_sparse::csr::CsrMatrix;
|
||||
use nalgebra_sparse::csc::CscMatrix;
|
||||
use crate::common::csc_strategy;
|
||||
use nalgebra::proptest::matrix;
|
||||
use nalgebra::DMatrix;
|
||||
use nalgebra_sparse::convert::serial::{
|
||||
convert_coo_csc, convert_coo_csr, convert_coo_dense, convert_csc_coo, convert_csc_csr,
|
||||
convert_csc_dense, convert_csr_coo, convert_csr_csc, convert_csr_dense, convert_dense_coo,
|
||||
convert_dense_csc, convert_dense_csr,
|
||||
};
|
||||
use nalgebra_sparse::coo::CooMatrix;
|
||||
use nalgebra_sparse::csc::CscMatrix;
|
||||
use nalgebra_sparse::csr::CsrMatrix;
|
||||
use nalgebra_sparse::proptest::{coo_no_duplicates, coo_with_duplicates, csc, csr};
|
||||
use proptest::prelude::*;
|
||||
|
||||
#[test]
|
||||
fn test_convert_dense_coo() {
|
||||
|
@ -41,16 +40,17 @@ fn test_convert_dense_coo() {
|
|||
// Here we implicitly test that the coo matrix is indeed constructed from column-major
|
||||
// iteration of the dense matrix.
|
||||
let dense = DMatrix::from_row_slice(2, 3, entries);
|
||||
let coo_no_dup = CooMatrix::try_from_triplets(2, 3,
|
||||
vec![0, 1, 0],
|
||||
vec![0, 1, 2],
|
||||
vec![1, 5, 3])
|
||||
.unwrap();
|
||||
let coo_dup = CooMatrix::try_from_triplets(2, 3,
|
||||
vec![0, 1, 0, 1],
|
||||
vec![0, 1, 2, 1],
|
||||
vec![1, -2, 3, 7])
|
||||
.unwrap();
|
||||
let coo_no_dup =
|
||||
CooMatrix::try_from_triplets(2, 3, vec![0, 1, 0], vec![0, 1, 2], vec![1, 5, 3])
|
||||
.unwrap();
|
||||
let coo_dup = CooMatrix::try_from_triplets(
|
||||
2,
|
||||
3,
|
||||
vec![0, 1, 0, 1],
|
||||
vec![0, 1, 2, 1],
|
||||
vec![1, -2, 3, 7],
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(CooMatrix::from(&dense), coo_no_dup);
|
||||
assert_eq!(DMatrix::from(&coo_dup), dense);
|
||||
|
@ -76,8 +76,9 @@ fn test_convert_coo_csr() {
|
|||
4,
|
||||
vec![0, 1, 2, 5],
|
||||
vec![1, 3, 0, 2, 3],
|
||||
vec![2, 4, 1, 1, 2]
|
||||
).unwrap();
|
||||
vec![2, 4, 1, 1, 2],
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(convert_coo_csr(&coo), expected_csr);
|
||||
}
|
||||
|
@ -101,8 +102,9 @@ fn test_convert_coo_csr() {
|
|||
4,
|
||||
vec![0, 1, 2, 5],
|
||||
vec![1, 3, 0, 2, 3],
|
||||
vec![5, 4, 1, 1, 4]
|
||||
).unwrap();
|
||||
vec![5, 4, 1, 1, 4],
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(convert_coo_csr(&coo), expected_csr);
|
||||
}
|
||||
|
@ -115,16 +117,18 @@ fn test_convert_csr_coo() {
|
|||
4,
|
||||
vec![0, 1, 2, 5],
|
||||
vec![1, 3, 0, 2, 3],
|
||||
vec![5, 4, 1, 1, 4]
|
||||
).unwrap();
|
||||
vec![5, 4, 1, 1, 4],
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let expected_coo = CooMatrix::try_from_triplets(
|
||||
3,
|
||||
4,
|
||||
vec![0, 1, 2, 2, 2],
|
||||
vec![1, 3, 0, 2, 3],
|
||||
vec![5, 4, 1, 1, 4]
|
||||
).unwrap();
|
||||
vec![5, 4, 1, 1, 4],
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(convert_csr_coo(&csr), expected_coo);
|
||||
}
|
||||
|
@ -148,8 +152,9 @@ fn test_convert_coo_csc() {
|
|||
4,
|
||||
vec![0, 1, 2, 3, 5],
|
||||
vec![2, 0, 2, 1, 2],
|
||||
vec![1, 2, 1, 4, 2]
|
||||
).unwrap();
|
||||
vec![1, 2, 1, 4, 2],
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(convert_coo_csc(&coo), expected_csc);
|
||||
}
|
||||
|
@ -173,8 +178,9 @@ fn test_convert_coo_csc() {
|
|||
4,
|
||||
vec![0, 1, 2, 3, 5],
|
||||
vec![2, 0, 2, 1, 2],
|
||||
vec![1, 5, 1, 4, 4]
|
||||
).unwrap();
|
||||
vec![1, 5, 1, 4, 4],
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(convert_coo_csc(&coo), expected_csc);
|
||||
}
|
||||
|
@ -187,16 +193,18 @@ fn test_convert_csc_coo() {
|
|||
4,
|
||||
vec![0, 1, 2, 3, 5],
|
||||
vec![2, 0, 2, 1, 2],
|
||||
vec![1, 2, 1, 4, 2]
|
||||
).unwrap();
|
||||
vec![1, 2, 1, 4, 2],
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let expected_coo = CooMatrix::try_from_triplets(
|
||||
3,
|
||||
4,
|
||||
vec![2, 0, 2, 1, 2],
|
||||
vec![0, 1, 2, 3, 3],
|
||||
vec![1, 2, 1, 4, 2]
|
||||
).unwrap();
|
||||
vec![1, 2, 1, 4, 2],
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(convert_csc_coo(&csc), expected_coo);
|
||||
}
|
||||
|
@ -209,7 +217,8 @@ fn test_convert_csr_csc_bidirectional() {
|
|||
vec![0, 3, 4, 6],
|
||||
vec![1, 2, 3, 0, 1, 3],
|
||||
vec![5, 3, 2, 2, 1, 4],
|
||||
).unwrap();
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let csc = CscMatrix::try_from_csc_data(
|
||||
3,
|
||||
|
@ -217,7 +226,8 @@ fn test_convert_csr_csc_bidirectional() {
|
|||
vec![0, 1, 3, 4, 6],
|
||||
vec![1, 0, 2, 0, 0, 2],
|
||||
vec![2, 5, 1, 3, 2, 4],
|
||||
).unwrap();
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(convert_csr_csc(&csr), csc);
|
||||
assert_eq!(convert_csc_csr(&csc), csr);
|
||||
|
@ -231,7 +241,8 @@ fn test_convert_csr_dense_bidirectional() {
|
|||
vec![0, 3, 4, 6],
|
||||
vec![1, 2, 3, 0, 1, 3],
|
||||
vec![5, 3, 2, 2, 1, 4],
|
||||
).unwrap();
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
#[rustfmt::skip]
|
||||
let dense = DMatrix::from_row_slice(3, 4, &[
|
||||
|
@ -252,7 +263,8 @@ fn test_convert_csc_dense_bidirectional() {
|
|||
vec![0, 1, 3, 4, 6],
|
||||
vec![1, 0, 2, 0, 0, 2],
|
||||
vec![2, 5, 1, 3, 2, 4],
|
||||
).unwrap();
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
#[rustfmt::skip]
|
||||
let dense = DMatrix::from_row_slice(3, 4, &[
|
||||
|
@ -265,29 +277,29 @@ fn test_convert_csc_dense_bidirectional() {
|
|||
assert_eq!(convert_dense_csc(&dense), csc);
|
||||
}
|
||||
|
||||
fn coo_strategy() -> impl Strategy<Value=CooMatrix<i32>> {
|
||||
coo_with_duplicates(-5 ..= 5, 0..=6usize, 0..=6usize, 40, 2)
|
||||
fn coo_strategy() -> impl Strategy<Value = CooMatrix<i32>> {
|
||||
coo_with_duplicates(-5..=5, 0..=6usize, 0..=6usize, 40, 2)
|
||||
}
|
||||
|
||||
fn coo_no_duplicates_strategy() -> impl Strategy<Value=CooMatrix<i32>> {
|
||||
coo_no_duplicates(-5 ..= 5, 0..=6usize, 0..=6usize, 40)
|
||||
fn coo_no_duplicates_strategy() -> impl Strategy<Value = CooMatrix<i32>> {
|
||||
coo_no_duplicates(-5..=5, 0..=6usize, 0..=6usize, 40)
|
||||
}
|
||||
|
||||
fn csr_strategy() -> impl Strategy<Value=CsrMatrix<i32>> {
|
||||
csr(-5 ..= 5, 0..=6usize, 0..=6usize, 40)
|
||||
fn csr_strategy() -> impl Strategy<Value = CsrMatrix<i32>> {
|
||||
csr(-5..=5, 0..=6usize, 0..=6usize, 40)
|
||||
}
|
||||
|
||||
/// Avoid generating explicit zero values so that it is possible to reason about sparsity patterns
|
||||
fn non_zero_csr_strategy() -> impl Strategy<Value=CsrMatrix<i32>> {
|
||||
csr(1 ..= 5, 0..=6usize, 0..=6usize, 40)
|
||||
fn non_zero_csr_strategy() -> impl Strategy<Value = CsrMatrix<i32>> {
|
||||
csr(1..=5, 0..=6usize, 0..=6usize, 40)
|
||||
}
|
||||
|
||||
/// Avoid generating explicit zero values so that it is possible to reason about sparsity patterns
|
||||
fn non_zero_csc_strategy() -> impl Strategy<Value=CscMatrix<i32>> {
|
||||
csc(1 ..= 5, 0..=6usize, 0..=6usize, 40)
|
||||
fn non_zero_csc_strategy() -> impl Strategy<Value = CscMatrix<i32>> {
|
||||
csc(1..=5, 0..=6usize, 0..=6usize, 40)
|
||||
}
|
||||
|
||||
fn dense_strategy() -> impl Strategy<Value=DMatrix<i32>> {
|
||||
fn dense_strategy() -> impl Strategy<Value = DMatrix<i32>> {
|
||||
matrix(-5..=5, 0..=6, 0..=6)
|
||||
}
|
||||
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
use nalgebra_sparse::{SparseFormatErrorKind};
|
||||
use nalgebra_sparse::coo::CooMatrix;
|
||||
use nalgebra::DMatrix;
|
||||
use crate::assert_panics;
|
||||
use nalgebra::DMatrix;
|
||||
use nalgebra_sparse::coo::CooMatrix;
|
||||
use nalgebra_sparse::SparseFormatErrorKind;
|
||||
|
||||
#[test]
|
||||
fn coo_construction_for_valid_data() {
|
||||
|
@ -10,8 +10,8 @@ fn coo_construction_for_valid_data() {
|
|||
|
||||
{
|
||||
// Zero matrix
|
||||
let coo = CooMatrix::<i32>::try_from_triplets(3, 2, Vec::new(), Vec::new(), Vec::new())
|
||||
.unwrap();
|
||||
let coo =
|
||||
CooMatrix::<i32>::try_from_triplets(3, 2, Vec::new(), Vec::new(), Vec::new()).unwrap();
|
||||
assert_eq!(coo.nrows(), 3);
|
||||
assert_eq!(coo.ncols(), 2);
|
||||
assert!(coo.triplet_iter().next().is_none());
|
||||
|
@ -27,8 +27,8 @@ fn coo_construction_for_valid_data() {
|
|||
let i = vec![0, 1, 0, 0, 2];
|
||||
let j = vec![0, 2, 1, 3, 3];
|
||||
let v = vec![2, 3, 7, 3, 1];
|
||||
let coo = CooMatrix::<i32>::try_from_triplets(3, 5, i.clone(), j.clone(), v.clone())
|
||||
.unwrap();
|
||||
let coo =
|
||||
CooMatrix::<i32>::try_from_triplets(3, 5, i.clone(), j.clone(), v.clone()).unwrap();
|
||||
assert_eq!(coo.nrows(), 3);
|
||||
assert_eq!(coo.ncols(), 5);
|
||||
|
||||
|
@ -59,8 +59,8 @@ fn coo_construction_for_valid_data() {
|
|||
let i = vec![0, 1, 0, 0, 0, 0, 2, 1];
|
||||
let j = vec![0, 2, 0, 1, 0, 3, 3, 2];
|
||||
let v = vec![2, 3, 4, 7, 1, 3, 1, 5];
|
||||
let coo = CooMatrix::<i32>::try_from_triplets(3, 5, i.clone(), j.clone(), v.clone())
|
||||
.unwrap();
|
||||
let coo =
|
||||
CooMatrix::<i32>::try_from_triplets(3, 5, i.clone(), j.clone(), v.clone()).unwrap();
|
||||
assert_eq!(coo.nrows(), 3);
|
||||
assert_eq!(coo.ncols(), 5);
|
||||
|
||||
|
@ -92,25 +92,37 @@ fn coo_try_from_triplets_reports_out_of_bounds_indices() {
|
|||
{
|
||||
// 0x0 matrix
|
||||
let result = CooMatrix::<i32>::try_from_triplets(0, 0, vec![0], vec![0], vec![2]);
|
||||
assert!(matches!(result.unwrap_err().kind(), SparseFormatErrorKind::IndexOutOfBounds));
|
||||
assert!(matches!(
|
||||
result.unwrap_err().kind(),
|
||||
SparseFormatErrorKind::IndexOutOfBounds
|
||||
));
|
||||
}
|
||||
|
||||
{
|
||||
// 1x1 matrix, row out of bounds
|
||||
let result = CooMatrix::<i32>::try_from_triplets(1, 1, vec![1], vec![0], vec![2]);
|
||||
assert!(matches!(result.unwrap_err().kind(), SparseFormatErrorKind::IndexOutOfBounds));
|
||||
assert!(matches!(
|
||||
result.unwrap_err().kind(),
|
||||
SparseFormatErrorKind::IndexOutOfBounds
|
||||
));
|
||||
}
|
||||
|
||||
{
|
||||
// 1x1 matrix, col out of bounds
|
||||
let result = CooMatrix::<i32>::try_from_triplets(1, 1, vec![0], vec![1], vec![2]);
|
||||
assert!(matches!(result.unwrap_err().kind(), SparseFormatErrorKind::IndexOutOfBounds));
|
||||
assert!(matches!(
|
||||
result.unwrap_err().kind(),
|
||||
SparseFormatErrorKind::IndexOutOfBounds
|
||||
));
|
||||
}
|
||||
|
||||
{
|
||||
// 1x1 matrix, row and col out of bounds
|
||||
let result = CooMatrix::<i32>::try_from_triplets(1, 1, vec![1], vec![1], vec![2]);
|
||||
assert!(matches!(result.unwrap_err().kind(), SparseFormatErrorKind::IndexOutOfBounds));
|
||||
assert!(matches!(
|
||||
result.unwrap_err().kind(),
|
||||
SparseFormatErrorKind::IndexOutOfBounds
|
||||
));
|
||||
}
|
||||
|
||||
{
|
||||
|
@ -119,7 +131,10 @@ fn coo_try_from_triplets_reports_out_of_bounds_indices() {
|
|||
let j = vec![0, 2, 1, 3, 3];
|
||||
let v = vec![2, 3, 7, 3, 1];
|
||||
let result = CooMatrix::<i32>::try_from_triplets(3, 5, i, j, v);
|
||||
assert!(matches!(result.unwrap_err().kind(), SparseFormatErrorKind::IndexOutOfBounds));
|
||||
assert!(matches!(
|
||||
result.unwrap_err().kind(),
|
||||
SparseFormatErrorKind::IndexOutOfBounds
|
||||
));
|
||||
}
|
||||
|
||||
{
|
||||
|
@ -128,7 +143,10 @@ fn coo_try_from_triplets_reports_out_of_bounds_indices() {
|
|||
let j = vec![0, 2, 1, 5, 3];
|
||||
let v = vec![2, 3, 7, 3, 1];
|
||||
let result = CooMatrix::<i32>::try_from_triplets(3, 5, i, j, v);
|
||||
assert!(matches!(result.unwrap_err().kind(), SparseFormatErrorKind::IndexOutOfBounds));
|
||||
assert!(matches!(
|
||||
result.unwrap_err().kind(),
|
||||
SparseFormatErrorKind::IndexOutOfBounds
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -137,16 +155,55 @@ fn coo_try_from_triplets_panics_on_mismatched_vectors() {
|
|||
// Check that try_from_triplets panics when the triplet vectors have different lengths
|
||||
macro_rules! assert_errs {
|
||||
($result:expr) => {
|
||||
assert!(matches!($result.unwrap_err().kind(), SparseFormatErrorKind::InvalidStructure))
|
||||
}
|
||||
assert!(matches!(
|
||||
$result.unwrap_err().kind(),
|
||||
SparseFormatErrorKind::InvalidStructure
|
||||
))
|
||||
};
|
||||
}
|
||||
|
||||
assert_errs!(CooMatrix::<i32>::try_from_triplets(3, 5, vec![1, 2], vec![0], vec![0]));
|
||||
assert_errs!(CooMatrix::<i32>::try_from_triplets(3, 5, vec![1], vec![0, 0], vec![0]));
|
||||
assert_errs!(CooMatrix::<i32>::try_from_triplets(3, 5, vec![1], vec![0], vec![0, 1]));
|
||||
assert_errs!(CooMatrix::<i32>::try_from_triplets(3, 5, vec![1, 2], vec![0, 1], vec![0]));
|
||||
assert_errs!(CooMatrix::<i32>::try_from_triplets(3, 5, vec![1], vec![0, 1], vec![0, 1]));
|
||||
assert_errs!(CooMatrix::<i32>::try_from_triplets(3, 5, vec![1, 1], vec![0], vec![0, 1]));
|
||||
assert_errs!(CooMatrix::<i32>::try_from_triplets(
|
||||
3,
|
||||
5,
|
||||
vec![1, 2],
|
||||
vec![0],
|
||||
vec![0]
|
||||
));
|
||||
assert_errs!(CooMatrix::<i32>::try_from_triplets(
|
||||
3,
|
||||
5,
|
||||
vec![1],
|
||||
vec![0, 0],
|
||||
vec![0]
|
||||
));
|
||||
assert_errs!(CooMatrix::<i32>::try_from_triplets(
|
||||
3,
|
||||
5,
|
||||
vec![1],
|
||||
vec![0],
|
||||
vec![0, 1]
|
||||
));
|
||||
assert_errs!(CooMatrix::<i32>::try_from_triplets(
|
||||
3,
|
||||
5,
|
||||
vec![1, 2],
|
||||
vec![0, 1],
|
||||
vec![0]
|
||||
));
|
||||
assert_errs!(CooMatrix::<i32>::try_from_triplets(
|
||||
3,
|
||||
5,
|
||||
vec![1],
|
||||
vec![0, 1],
|
||||
vec![0, 1]
|
||||
));
|
||||
assert_errs!(CooMatrix::<i32>::try_from_triplets(
|
||||
3,
|
||||
5,
|
||||
vec![1, 1],
|
||||
vec![0],
|
||||
vec![0, 1]
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
@ -157,10 +214,16 @@ fn coo_push_valid_entries() {
|
|||
assert_eq!(coo.triplet_iter().collect::<Vec<_>>(), vec![(0, 0, &1)]);
|
||||
|
||||
coo.push(0, 0, 2);
|
||||
assert_eq!(coo.triplet_iter().collect::<Vec<_>>(), vec![(0, 0, &1), (0, 0, &2)]);
|
||||
assert_eq!(
|
||||
coo.triplet_iter().collect::<Vec<_>>(),
|
||||
vec![(0, 0, &1), (0, 0, &2)]
|
||||
);
|
||||
|
||||
coo.push(2, 2, 3);
|
||||
assert_eq!(coo.triplet_iter().collect::<Vec<_>>(), vec![(0, 0, &1), (0, 0, &2), (2, 2, &3)]);
|
||||
assert_eq!(
|
||||
coo.triplet_iter().collect::<Vec<_>>(),
|
||||
vec![(0, 0, &1), (0, 0, &2), (2, 2, &3)]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
use nalgebra::DMatrix;
|
||||
use nalgebra_sparse::csc::CscMatrix;
|
||||
use nalgebra_sparse::SparseFormatErrorKind;
|
||||
use nalgebra::DMatrix;
|
||||
|
||||
use proptest::prelude::*;
|
||||
use proptest::sample::subsequence;
|
||||
|
@ -42,7 +42,10 @@ fn csc_matrix_valid_data() {
|
|||
assert_eq!(matrix.col_mut(0).row_indices(), &[]);
|
||||
assert_eq!(matrix.col_mut(0).values(), &[]);
|
||||
assert_eq!(matrix.col_mut(0).values_mut(), &[]);
|
||||
assert_eq!(matrix.col_mut(0).rows_and_values_mut(), ([].as_ref(), [].as_mut()));
|
||||
assert_eq!(
|
||||
matrix.col_mut(0).rows_and_values_mut(),
|
||||
([].as_ref(), [].as_mut())
|
||||
);
|
||||
|
||||
assert_eq!(matrix.col(1).nrows(), 2);
|
||||
assert_eq!(matrix.col(1).nnz(), 0);
|
||||
|
@ -53,7 +56,10 @@ fn csc_matrix_valid_data() {
|
|||
assert_eq!(matrix.col_mut(1).row_indices(), &[]);
|
||||
assert_eq!(matrix.col_mut(1).values(), &[]);
|
||||
assert_eq!(matrix.col_mut(1).values_mut(), &[]);
|
||||
assert_eq!(matrix.col_mut(1).rows_and_values_mut(), ([].as_ref(), [].as_mut()));
|
||||
assert_eq!(
|
||||
matrix.col_mut(1).rows_and_values_mut(),
|
||||
([].as_ref(), [].as_mut())
|
||||
);
|
||||
|
||||
assert_eq!(matrix.col(2).nrows(), 2);
|
||||
assert_eq!(matrix.col(2).nnz(), 0);
|
||||
|
@ -64,7 +70,10 @@ fn csc_matrix_valid_data() {
|
|||
assert_eq!(matrix.col_mut(2).row_indices(), &[]);
|
||||
assert_eq!(matrix.col_mut(2).values(), &[]);
|
||||
assert_eq!(matrix.col_mut(2).values_mut(), &[]);
|
||||
assert_eq!(matrix.col_mut(2).rows_and_values_mut(), ([].as_ref(), [].as_mut()));
|
||||
assert_eq!(
|
||||
matrix.col_mut(2).rows_and_values_mut(),
|
||||
([].as_ref(), [].as_mut())
|
||||
);
|
||||
|
||||
assert!(matrix.get_col(3).is_none());
|
||||
assert!(matrix.get_col_mut(3).is_none());
|
||||
|
@ -81,11 +90,9 @@ fn csc_matrix_valid_data() {
|
|||
let offsets = vec![0, 2, 2, 5];
|
||||
let indices = vec![0, 5, 1, 2, 3];
|
||||
let values = vec![0, 1, 2, 3, 4];
|
||||
let mut matrix = CscMatrix::try_from_csc_data(6,
|
||||
3,
|
||||
offsets.clone(),
|
||||
indices.clone(),
|
||||
values.clone()).unwrap();
|
||||
let mut matrix =
|
||||
CscMatrix::try_from_csc_data(6, 3, offsets.clone(), indices.clone(), values.clone())
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(matrix.nrows(), 6);
|
||||
assert_eq!(matrix.ncols(), 3);
|
||||
|
@ -95,10 +102,20 @@ fn csc_matrix_valid_data() {
|
|||
assert_eq!(matrix.values(), &[0, 1, 2, 3, 4]);
|
||||
|
||||
let expected_triplets = vec![(0, 0, 0), (5, 0, 1), (1, 2, 2), (2, 2, 3), (3, 2, 4)];
|
||||
assert_eq!(matrix.triplet_iter().map(|(i, j, v)| (i, j, *v)).collect::<Vec<_>>(),
|
||||
expected_triplets);
|
||||
assert_eq!(matrix.triplet_iter_mut().map(|(i, j, v)| (i, j, *v)).collect::<Vec<_>>(),
|
||||
expected_triplets);
|
||||
assert_eq!(
|
||||
matrix
|
||||
.triplet_iter()
|
||||
.map(|(i, j, v)| (i, j, *v))
|
||||
.collect::<Vec<_>>(),
|
||||
expected_triplets
|
||||
);
|
||||
assert_eq!(
|
||||
matrix
|
||||
.triplet_iter_mut()
|
||||
.map(|(i, j, v)| (i, j, *v))
|
||||
.collect::<Vec<_>>(),
|
||||
expected_triplets
|
||||
);
|
||||
|
||||
assert_eq!(matrix.col(0).nrows(), 6);
|
||||
assert_eq!(matrix.col(0).nnz(), 2);
|
||||
|
@ -109,7 +126,10 @@ fn csc_matrix_valid_data() {
|
|||
assert_eq!(matrix.col_mut(0).row_indices(), &[0, 5]);
|
||||
assert_eq!(matrix.col_mut(0).values(), &[0, 1]);
|
||||
assert_eq!(matrix.col_mut(0).values_mut(), &[0, 1]);
|
||||
assert_eq!(matrix.col_mut(0).rows_and_values_mut(), ([0, 5].as_ref(), [0, 1].as_mut()));
|
||||
assert_eq!(
|
||||
matrix.col_mut(0).rows_and_values_mut(),
|
||||
([0, 5].as_ref(), [0, 1].as_mut())
|
||||
);
|
||||
|
||||
assert_eq!(matrix.col(1).nrows(), 6);
|
||||
assert_eq!(matrix.col(1).nnz(), 0);
|
||||
|
@ -120,7 +140,10 @@ fn csc_matrix_valid_data() {
|
|||
assert_eq!(matrix.col_mut(1).row_indices(), &[]);
|
||||
assert_eq!(matrix.col_mut(1).values(), &[]);
|
||||
assert_eq!(matrix.col_mut(1).values_mut(), &[]);
|
||||
assert_eq!(matrix.col_mut(1).rows_and_values_mut(), ([].as_ref(), [].as_mut()));
|
||||
assert_eq!(
|
||||
matrix.col_mut(1).rows_and_values_mut(),
|
||||
([].as_ref(), [].as_mut())
|
||||
);
|
||||
|
||||
assert_eq!(matrix.col(2).nrows(), 6);
|
||||
assert_eq!(matrix.col(2).nnz(), 3);
|
||||
|
@ -131,7 +154,10 @@ fn csc_matrix_valid_data() {
|
|||
assert_eq!(matrix.col_mut(2).row_indices(), &[1, 2, 3]);
|
||||
assert_eq!(matrix.col_mut(2).values(), &[2, 3, 4]);
|
||||
assert_eq!(matrix.col_mut(2).values_mut(), &[2, 3, 4]);
|
||||
assert_eq!(matrix.col_mut(2).rows_and_values_mut(), ([1, 2, 3].as_ref(), [2, 3, 4].as_mut()));
|
||||
assert_eq!(
|
||||
matrix.col_mut(2).rows_and_values_mut(),
|
||||
([1, 2, 3].as_ref(), [2, 3, 4].as_mut())
|
||||
);
|
||||
|
||||
assert!(matrix.get_col(3).is_none());
|
||||
assert!(matrix.get_col_mut(3).is_none());
|
||||
|
@ -146,11 +172,13 @@ fn csc_matrix_valid_data() {
|
|||
|
||||
#[test]
|
||||
fn csc_matrix_try_from_invalid_csc_data() {
|
||||
|
||||
{
|
||||
// Empty offset array (invalid length)
|
||||
let matrix = CscMatrix::try_from_csc_data(0, 0, Vec::new(), Vec::new(), Vec::<u32>::new());
|
||||
assert_eq!(matrix.unwrap_err().kind(), &SparseFormatErrorKind::InvalidStructure);
|
||||
assert_eq!(
|
||||
matrix.unwrap_err().kind(),
|
||||
&SparseFormatErrorKind::InvalidStructure
|
||||
);
|
||||
}
|
||||
|
||||
{
|
||||
|
@ -160,7 +188,10 @@ fn csc_matrix_try_from_invalid_csc_data() {
|
|||
let values = vec![0, 1, 2, 3, 4];
|
||||
|
||||
let matrix = CscMatrix::try_from_csc_data(6, 3, offsets, indices, values);
|
||||
assert_eq!(matrix.unwrap_err().kind(), &SparseFormatErrorKind::InvalidStructure);
|
||||
assert_eq!(
|
||||
matrix.unwrap_err().kind(),
|
||||
&SparseFormatErrorKind::InvalidStructure
|
||||
);
|
||||
}
|
||||
|
||||
{
|
||||
|
@ -169,7 +200,10 @@ fn csc_matrix_try_from_invalid_csc_data() {
|
|||
let indices = vec![0, 5, 1, 2, 3];
|
||||
let values = vec![0, 1, 2, 3, 4];
|
||||
let matrix = CscMatrix::try_from_csc_data(6, 3, offsets, indices, values);
|
||||
assert_eq!(matrix.unwrap_err().kind(), &SparseFormatErrorKind::InvalidStructure);
|
||||
assert_eq!(
|
||||
matrix.unwrap_err().kind(),
|
||||
&SparseFormatErrorKind::InvalidStructure
|
||||
);
|
||||
}
|
||||
|
||||
{
|
||||
|
@ -178,7 +212,10 @@ fn csc_matrix_try_from_invalid_csc_data() {
|
|||
let indices = vec![0, 5, 1, 2, 3];
|
||||
let values = vec![0, 1, 2, 3, 4];
|
||||
let matrix = CscMatrix::try_from_csc_data(6, 3, offsets, indices, values);
|
||||
assert_eq!(matrix.unwrap_err().kind(), &SparseFormatErrorKind::InvalidStructure);
|
||||
assert_eq!(
|
||||
matrix.unwrap_err().kind(),
|
||||
&SparseFormatErrorKind::InvalidStructure
|
||||
);
|
||||
}
|
||||
|
||||
{
|
||||
|
@ -187,7 +224,10 @@ fn csc_matrix_try_from_invalid_csc_data() {
|
|||
let indices = vec![0, 5, 1, 2, 3];
|
||||
let values = vec![0, 1, 2, 3, 4];
|
||||
let matrix = CscMatrix::try_from_csc_data(6, 3, offsets, indices, values);
|
||||
assert_eq!(matrix.unwrap_err().kind(), &SparseFormatErrorKind::InvalidStructure);
|
||||
assert_eq!(
|
||||
matrix.unwrap_err().kind(),
|
||||
&SparseFormatErrorKind::InvalidStructure
|
||||
);
|
||||
}
|
||||
|
||||
{
|
||||
|
@ -196,7 +236,10 @@ fn csc_matrix_try_from_invalid_csc_data() {
|
|||
let indices = vec![0, 1, 2, 3, 4];
|
||||
let values = vec![0, 1, 2, 3, 4];
|
||||
let matrix = CscMatrix::try_from_csc_data(6, 3, offsets, indices, values);
|
||||
assert_eq!(matrix.unwrap_err().kind(), &SparseFormatErrorKind::InvalidStructure);
|
||||
assert_eq!(
|
||||
matrix.unwrap_err().kind(),
|
||||
&SparseFormatErrorKind::InvalidStructure
|
||||
);
|
||||
}
|
||||
|
||||
{
|
||||
|
@ -205,7 +248,10 @@ fn csc_matrix_try_from_invalid_csc_data() {
|
|||
let indices = vec![0, 2, 3, 1, 4];
|
||||
let values = vec![0, 1, 2, 3, 4];
|
||||
let matrix = CscMatrix::try_from_csc_data(6, 3, offsets, indices, values);
|
||||
assert_eq!(matrix.unwrap_err().kind(), &SparseFormatErrorKind::InvalidStructure);
|
||||
assert_eq!(
|
||||
matrix.unwrap_err().kind(),
|
||||
&SparseFormatErrorKind::InvalidStructure
|
||||
);
|
||||
}
|
||||
|
||||
{
|
||||
|
@ -214,7 +260,10 @@ fn csc_matrix_try_from_invalid_csc_data() {
|
|||
let indices = vec![0, 6, 1, 2, 3];
|
||||
let values = vec![0, 1, 2, 3, 4];
|
||||
let matrix = CscMatrix::try_from_csc_data(6, 3, offsets, indices, values);
|
||||
assert_eq!(matrix.unwrap_err().kind(), &SparseFormatErrorKind::IndexOutOfBounds);
|
||||
assert_eq!(
|
||||
matrix.unwrap_err().kind(),
|
||||
&SparseFormatErrorKind::IndexOutOfBounds
|
||||
);
|
||||
}
|
||||
|
||||
{
|
||||
|
@ -223,9 +272,11 @@ fn csc_matrix_try_from_invalid_csc_data() {
|
|||
let indices = vec![0, 5, 2, 2, 3];
|
||||
let values = vec![0, 1, 2, 3, 4];
|
||||
let matrix = CscMatrix::try_from_csc_data(6, 3, offsets, indices, values);
|
||||
assert_eq!(matrix.unwrap_err().kind(), &SparseFormatErrorKind::DuplicateEntry);
|
||||
assert_eq!(
|
||||
matrix.unwrap_err().kind(),
|
||||
&SparseFormatErrorKind::DuplicateEntry
|
||||
);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
@ -239,11 +290,7 @@ fn csc_disassemble_avoids_clone_when_owned() {
|
|||
let offsets_ptr = offsets.as_ptr();
|
||||
let indices_ptr = indices.as_ptr();
|
||||
let values_ptr = values.as_ptr();
|
||||
let matrix = CscMatrix::try_from_csc_data(6,
|
||||
3,
|
||||
offsets,
|
||||
indices,
|
||||
values).unwrap();
|
||||
let matrix = CscMatrix::try_from_csc_data(6, 3, offsets, indices, values).unwrap();
|
||||
|
||||
let (offsets, indices, values) = matrix.disassemble();
|
||||
assert_eq!(offsets.as_ptr(), offsets_ptr);
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
use nalgebra::DMatrix;
|
||||
use nalgebra_sparse::csr::CsrMatrix;
|
||||
use nalgebra_sparse::SparseFormatErrorKind;
|
||||
use nalgebra::DMatrix;
|
||||
|
||||
use proptest::prelude::*;
|
||||
use proptest::sample::subsequence;
|
||||
|
@ -9,7 +9,6 @@ use crate::common::csr_strategy;
|
|||
|
||||
use std::collections::HashSet;
|
||||
|
||||
|
||||
#[test]
|
||||
fn csr_matrix_valid_data() {
|
||||
// Construct matrix from valid data and check that selected methods return results
|
||||
|
@ -43,7 +42,10 @@ fn csr_matrix_valid_data() {
|
|||
assert_eq!(matrix.row_mut(0).col_indices(), &[]);
|
||||
assert_eq!(matrix.row_mut(0).values(), &[]);
|
||||
assert_eq!(matrix.row_mut(0).values_mut(), &[]);
|
||||
assert_eq!(matrix.row_mut(0).cols_and_values_mut(), ([].as_ref(), [].as_mut()));
|
||||
assert_eq!(
|
||||
matrix.row_mut(0).cols_and_values_mut(),
|
||||
([].as_ref(), [].as_mut())
|
||||
);
|
||||
|
||||
assert_eq!(matrix.row(1).ncols(), 2);
|
||||
assert_eq!(matrix.row(1).nnz(), 0);
|
||||
|
@ -54,7 +56,10 @@ fn csr_matrix_valid_data() {
|
|||
assert_eq!(matrix.row_mut(1).col_indices(), &[]);
|
||||
assert_eq!(matrix.row_mut(1).values(), &[]);
|
||||
assert_eq!(matrix.row_mut(1).values_mut(), &[]);
|
||||
assert_eq!(matrix.row_mut(1).cols_and_values_mut(), ([].as_ref(), [].as_mut()));
|
||||
assert_eq!(
|
||||
matrix.row_mut(1).cols_and_values_mut(),
|
||||
([].as_ref(), [].as_mut())
|
||||
);
|
||||
|
||||
assert_eq!(matrix.row(2).ncols(), 2);
|
||||
assert_eq!(matrix.row(2).nnz(), 0);
|
||||
|
@ -65,7 +70,10 @@ fn csr_matrix_valid_data() {
|
|||
assert_eq!(matrix.row_mut(2).col_indices(), &[]);
|
||||
assert_eq!(matrix.row_mut(2).values(), &[]);
|
||||
assert_eq!(matrix.row_mut(2).values_mut(), &[]);
|
||||
assert_eq!(matrix.row_mut(2).cols_and_values_mut(), ([].as_ref(), [].as_mut()));
|
||||
assert_eq!(
|
||||
matrix.row_mut(2).cols_and_values_mut(),
|
||||
([].as_ref(), [].as_mut())
|
||||
);
|
||||
|
||||
assert!(matrix.get_row(3).is_none());
|
||||
assert!(matrix.get_row_mut(3).is_none());
|
||||
|
@ -82,11 +90,9 @@ fn csr_matrix_valid_data() {
|
|||
let offsets = vec![0, 2, 2, 5];
|
||||
let indices = vec![0, 5, 1, 2, 3];
|
||||
let values = vec![0, 1, 2, 3, 4];
|
||||
let mut matrix = CsrMatrix::try_from_csr_data(3,
|
||||
6,
|
||||
offsets.clone(),
|
||||
indices.clone(),
|
||||
values.clone()).unwrap();
|
||||
let mut matrix =
|
||||
CsrMatrix::try_from_csr_data(3, 6, offsets.clone(), indices.clone(), values.clone())
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(matrix.nrows(), 3);
|
||||
assert_eq!(matrix.ncols(), 6);
|
||||
|
@ -96,10 +102,20 @@ fn csr_matrix_valid_data() {
|
|||
assert_eq!(matrix.values(), &[0, 1, 2, 3, 4]);
|
||||
|
||||
let expected_triplets = vec![(0, 0, 0), (0, 5, 1), (2, 1, 2), (2, 2, 3), (2, 3, 4)];
|
||||
assert_eq!(matrix.triplet_iter().map(|(i, j, v)| (i, j, *v)).collect::<Vec<_>>(),
|
||||
expected_triplets);
|
||||
assert_eq!(matrix.triplet_iter_mut().map(|(i, j, v)| (i, j, *v)).collect::<Vec<_>>(),
|
||||
expected_triplets);
|
||||
assert_eq!(
|
||||
matrix
|
||||
.triplet_iter()
|
||||
.map(|(i, j, v)| (i, j, *v))
|
||||
.collect::<Vec<_>>(),
|
||||
expected_triplets
|
||||
);
|
||||
assert_eq!(
|
||||
matrix
|
||||
.triplet_iter_mut()
|
||||
.map(|(i, j, v)| (i, j, *v))
|
||||
.collect::<Vec<_>>(),
|
||||
expected_triplets
|
||||
);
|
||||
|
||||
assert_eq!(matrix.row(0).ncols(), 6);
|
||||
assert_eq!(matrix.row(0).nnz(), 2);
|
||||
|
@ -110,7 +126,10 @@ fn csr_matrix_valid_data() {
|
|||
assert_eq!(matrix.row_mut(0).col_indices(), &[0, 5]);
|
||||
assert_eq!(matrix.row_mut(0).values(), &[0, 1]);
|
||||
assert_eq!(matrix.row_mut(0).values_mut(), &[0, 1]);
|
||||
assert_eq!(matrix.row_mut(0).cols_and_values_mut(), ([0, 5].as_ref(), [0, 1].as_mut()));
|
||||
assert_eq!(
|
||||
matrix.row_mut(0).cols_and_values_mut(),
|
||||
([0, 5].as_ref(), [0, 1].as_mut())
|
||||
);
|
||||
|
||||
assert_eq!(matrix.row(1).ncols(), 6);
|
||||
assert_eq!(matrix.row(1).nnz(), 0);
|
||||
|
@ -121,7 +140,10 @@ fn csr_matrix_valid_data() {
|
|||
assert_eq!(matrix.row_mut(1).col_indices(), &[]);
|
||||
assert_eq!(matrix.row_mut(1).values(), &[]);
|
||||
assert_eq!(matrix.row_mut(1).values_mut(), &[]);
|
||||
assert_eq!(matrix.row_mut(1).cols_and_values_mut(), ([].as_ref(), [].as_mut()));
|
||||
assert_eq!(
|
||||
matrix.row_mut(1).cols_and_values_mut(),
|
||||
([].as_ref(), [].as_mut())
|
||||
);
|
||||
|
||||
assert_eq!(matrix.row(2).ncols(), 6);
|
||||
assert_eq!(matrix.row(2).nnz(), 3);
|
||||
|
@ -132,7 +154,10 @@ fn csr_matrix_valid_data() {
|
|||
assert_eq!(matrix.row_mut(2).col_indices(), &[1, 2, 3]);
|
||||
assert_eq!(matrix.row_mut(2).values(), &[2, 3, 4]);
|
||||
assert_eq!(matrix.row_mut(2).values_mut(), &[2, 3, 4]);
|
||||
assert_eq!(matrix.row_mut(2).cols_and_values_mut(), ([1, 2, 3].as_ref(), [2, 3, 4].as_mut()));
|
||||
assert_eq!(
|
||||
matrix.row_mut(2).cols_and_values_mut(),
|
||||
([1, 2, 3].as_ref(), [2, 3, 4].as_mut())
|
||||
);
|
||||
|
||||
assert!(matrix.get_row(3).is_none());
|
||||
assert!(matrix.get_row_mut(3).is_none());
|
||||
|
@ -147,11 +172,13 @@ fn csr_matrix_valid_data() {
|
|||
|
||||
#[test]
|
||||
fn csr_matrix_try_from_invalid_csr_data() {
|
||||
|
||||
{
|
||||
// Empty offset array (invalid length)
|
||||
let matrix = CsrMatrix::try_from_csr_data(0, 0, Vec::new(), Vec::new(), Vec::<u32>::new());
|
||||
assert_eq!(matrix.unwrap_err().kind(), &SparseFormatErrorKind::InvalidStructure);
|
||||
assert_eq!(
|
||||
matrix.unwrap_err().kind(),
|
||||
&SparseFormatErrorKind::InvalidStructure
|
||||
);
|
||||
}
|
||||
|
||||
{
|
||||
|
@ -161,7 +188,10 @@ fn csr_matrix_try_from_invalid_csr_data() {
|
|||
let values = vec![0, 1, 2, 3, 4];
|
||||
|
||||
let matrix = CsrMatrix::try_from_csr_data(3, 6, offsets, indices, values);
|
||||
assert_eq!(matrix.unwrap_err().kind(), &SparseFormatErrorKind::InvalidStructure);
|
||||
assert_eq!(
|
||||
matrix.unwrap_err().kind(),
|
||||
&SparseFormatErrorKind::InvalidStructure
|
||||
);
|
||||
}
|
||||
|
||||
{
|
||||
|
@ -170,7 +200,10 @@ fn csr_matrix_try_from_invalid_csr_data() {
|
|||
let indices = vec![0, 5, 1, 2, 3];
|
||||
let values = vec![0, 1, 2, 3, 4];
|
||||
let matrix = CsrMatrix::try_from_csr_data(3, 6, offsets, indices, values);
|
||||
assert_eq!(matrix.unwrap_err().kind(), &SparseFormatErrorKind::InvalidStructure);
|
||||
assert_eq!(
|
||||
matrix.unwrap_err().kind(),
|
||||
&SparseFormatErrorKind::InvalidStructure
|
||||
);
|
||||
}
|
||||
|
||||
{
|
||||
|
@ -179,7 +212,10 @@ fn csr_matrix_try_from_invalid_csr_data() {
|
|||
let indices = vec![0, 5, 1, 2, 3];
|
||||
let values = vec![0, 1, 2, 3, 4];
|
||||
let matrix = CsrMatrix::try_from_csr_data(3, 6, offsets, indices, values);
|
||||
assert_eq!(matrix.unwrap_err().kind(), &SparseFormatErrorKind::InvalidStructure);
|
||||
assert_eq!(
|
||||
matrix.unwrap_err().kind(),
|
||||
&SparseFormatErrorKind::InvalidStructure
|
||||
);
|
||||
}
|
||||
|
||||
{
|
||||
|
@ -188,7 +224,10 @@ fn csr_matrix_try_from_invalid_csr_data() {
|
|||
let indices = vec![0, 5, 1, 2, 3];
|
||||
let values = vec![0, 1, 2, 3, 4];
|
||||
let matrix = CsrMatrix::try_from_csr_data(3, 6, offsets, indices, values);
|
||||
assert_eq!(matrix.unwrap_err().kind(), &SparseFormatErrorKind::InvalidStructure);
|
||||
assert_eq!(
|
||||
matrix.unwrap_err().kind(),
|
||||
&SparseFormatErrorKind::InvalidStructure
|
||||
);
|
||||
}
|
||||
|
||||
{
|
||||
|
@ -197,7 +236,10 @@ fn csr_matrix_try_from_invalid_csr_data() {
|
|||
let indices = vec![0, 1, 2, 3, 4];
|
||||
let values = vec![0, 1, 2, 3, 4];
|
||||
let matrix = CsrMatrix::try_from_csr_data(3, 6, offsets, indices, values);
|
||||
assert_eq!(matrix.unwrap_err().kind(), &SparseFormatErrorKind::InvalidStructure);
|
||||
assert_eq!(
|
||||
matrix.unwrap_err().kind(),
|
||||
&SparseFormatErrorKind::InvalidStructure
|
||||
);
|
||||
}
|
||||
|
||||
{
|
||||
|
@ -206,7 +248,10 @@ fn csr_matrix_try_from_invalid_csr_data() {
|
|||
let indices = vec![0, 2, 3, 1, 4];
|
||||
let values = vec![0, 1, 2, 3, 4];
|
||||
let matrix = CsrMatrix::try_from_csr_data(3, 6, offsets, indices, values);
|
||||
assert_eq!(matrix.unwrap_err().kind(), &SparseFormatErrorKind::InvalidStructure);
|
||||
assert_eq!(
|
||||
matrix.unwrap_err().kind(),
|
||||
&SparseFormatErrorKind::InvalidStructure
|
||||
);
|
||||
}
|
||||
|
||||
{
|
||||
|
@ -215,7 +260,10 @@ fn csr_matrix_try_from_invalid_csr_data() {
|
|||
let indices = vec![0, 6, 1, 2, 3];
|
||||
let values = vec![0, 1, 2, 3, 4];
|
||||
let matrix = CsrMatrix::try_from_csr_data(3, 6, offsets, indices, values);
|
||||
assert_eq!(matrix.unwrap_err().kind(), &SparseFormatErrorKind::IndexOutOfBounds);
|
||||
assert_eq!(
|
||||
matrix.unwrap_err().kind(),
|
||||
&SparseFormatErrorKind::IndexOutOfBounds
|
||||
);
|
||||
}
|
||||
|
||||
{
|
||||
|
@ -224,9 +272,11 @@ fn csr_matrix_try_from_invalid_csr_data() {
|
|||
let indices = vec![0, 5, 2, 2, 3];
|
||||
let values = vec![0, 1, 2, 3, 4];
|
||||
let matrix = CsrMatrix::try_from_csr_data(3, 6, offsets, indices, values);
|
||||
assert_eq!(matrix.unwrap_err().kind(), &SparseFormatErrorKind::DuplicateEntry);
|
||||
assert_eq!(
|
||||
matrix.unwrap_err().kind(),
|
||||
&SparseFormatErrorKind::DuplicateEntry
|
||||
);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
@ -240,11 +290,7 @@ fn csr_disassemble_avoids_clone_when_owned() {
|
|||
let offsets_ptr = offsets.as_ptr();
|
||||
let indices_ptr = indices.as_ptr();
|
||||
let values_ptr = values.as_ptr();
|
||||
let matrix = CsrMatrix::try_from_csr_data(3,
|
||||
6,
|
||||
offsets,
|
||||
indices,
|
||||
values).unwrap();
|
||||
let matrix = CsrMatrix::try_from_csr_data(3, 6, offsets, indices, values).unwrap();
|
||||
|
||||
let (offsets, indices, values) = matrix.disassemble();
|
||||
assert_eq!(offsets.as_ptr(), offsets_ptr);
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
mod coo;
|
||||
mod cholesky;
|
||||
mod convert_serial;
|
||||
mod coo;
|
||||
mod csc;
|
||||
mod csr;
|
||||
mod ops;
|
||||
mod pattern;
|
||||
mod csr;
|
||||
mod csc;
|
||||
mod convert_serial;
|
||||
mod proptest;
|
|
@ -1,13 +1,19 @@
|
|||
use crate::common::{csc_strategy, csr_strategy, PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ, PROPTEST_I32_VALUE_STRATEGY, non_zero_i32_value_strategy, value_strategy};
|
||||
use nalgebra_sparse::ops::serial::{spmm_csr_dense, spmm_csc_dense, spadd_pattern, spadd_csr_prealloc, spadd_csc_prealloc, spmm_csr_prealloc, spmm_csc_prealloc, spsolve_csc_lower_triangular, spmm_csr_pattern};
|
||||
use nalgebra_sparse::ops::{Op};
|
||||
use nalgebra_sparse::csr::CsrMatrix;
|
||||
use crate::common::{
|
||||
csc_strategy, csr_strategy, non_zero_i32_value_strategy, value_strategy,
|
||||
PROPTEST_I32_VALUE_STRATEGY, PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ,
|
||||
};
|
||||
use nalgebra_sparse::csc::CscMatrix;
|
||||
use nalgebra_sparse::proptest::{csc, csr, sparsity_pattern};
|
||||
use nalgebra_sparse::csr::CsrMatrix;
|
||||
use nalgebra_sparse::ops::serial::{
|
||||
spadd_csc_prealloc, spadd_csr_prealloc, spadd_pattern, spmm_csc_dense, spmm_csc_prealloc,
|
||||
spmm_csr_dense, spmm_csr_pattern, spmm_csr_prealloc, spsolve_csc_lower_triangular,
|
||||
};
|
||||
use nalgebra_sparse::ops::Op;
|
||||
use nalgebra_sparse::pattern::SparsityPattern;
|
||||
use nalgebra_sparse::proptest::{csc, csr, sparsity_pattern};
|
||||
|
||||
use nalgebra::{DMatrix, Scalar, DMatrixSliceMut, DMatrixSlice};
|
||||
use nalgebra::proptest::{matrix, vector};
|
||||
use nalgebra::{DMatrix, DMatrixSlice, DMatrixSliceMut, Scalar};
|
||||
|
||||
use proptest::prelude::*;
|
||||
|
||||
|
@ -17,19 +23,15 @@ use std::panic::catch_unwind;
|
|||
|
||||
/// Represents the sparsity pattern of a CSR matrix as a dense matrix with 0/1
|
||||
fn dense_csr_pattern(pattern: &SparsityPattern) -> DMatrix<i32> {
|
||||
let boolean_csr = CsrMatrix::try_from_pattern_and_values(
|
||||
pattern.clone(),
|
||||
vec![1; pattern.nnz()])
|
||||
.unwrap();
|
||||
let boolean_csr =
|
||||
CsrMatrix::try_from_pattern_and_values(pattern.clone(), vec![1; pattern.nnz()]).unwrap();
|
||||
DMatrix::from(&boolean_csr)
|
||||
}
|
||||
|
||||
/// Represents the sparsity pattern of a CSC matrix as a dense matrix with 0/1
|
||||
fn dense_csc_pattern(pattern: &SparsityPattern) -> DMatrix<i32> {
|
||||
let boolean_csc = CscMatrix::try_from_pattern_and_values(
|
||||
pattern.clone(),
|
||||
vec![1; pattern.nnz()])
|
||||
.unwrap();
|
||||
let boolean_csc =
|
||||
CscMatrix::try_from_pattern_and_values(pattern.clone(), vec![1; pattern.nnz()]).unwrap();
|
||||
DMatrix::from(&boolean_csc)
|
||||
}
|
||||
|
||||
|
@ -53,7 +55,7 @@ struct SpmmCscDenseArgs<T: Scalar> {
|
|||
|
||||
/// Returns matrices C, A and B with compatible dimensions such that it can be used
|
||||
/// in an `spmm` operation `C = beta * C + alpha * trans(A) * trans(B)`.
|
||||
fn spmm_csr_dense_args_strategy() -> impl Strategy<Value=SpmmCsrDenseArgs<i32>> {
|
||||
fn spmm_csr_dense_args_strategy() -> impl Strategy<Value = SpmmCsrDenseArgs<i32>> {
|
||||
let max_nnz = PROPTEST_MAX_NNZ;
|
||||
let value_strategy = PROPTEST_I32_VALUE_STRATEGY;
|
||||
let c_rows = PROPTEST_MATRIX_DIM;
|
||||
|
@ -62,14 +64,23 @@ fn spmm_csr_dense_args_strategy() -> impl Strategy<Value=SpmmCsrDenseArgs<i32>>
|
|||
let trans_strategy = trans_strategy();
|
||||
let c_matrix_strategy = matrix(value_strategy.clone(), c_rows, c_cols);
|
||||
|
||||
(c_matrix_strategy, common_dim, trans_strategy.clone(), trans_strategy.clone())
|
||||
(
|
||||
c_matrix_strategy,
|
||||
common_dim,
|
||||
trans_strategy.clone(),
|
||||
trans_strategy.clone(),
|
||||
)
|
||||
.prop_flat_map(move |(c, common_dim, trans_a, trans_b)| {
|
||||
let a_shape =
|
||||
if trans_a { (common_dim, c.nrows()) }
|
||||
else { (c.nrows(), common_dim) };
|
||||
let b_shape =
|
||||
if trans_b { (c.ncols(), common_dim) }
|
||||
else { (common_dim, c.ncols()) };
|
||||
let a_shape = if trans_a {
|
||||
(common_dim, c.nrows())
|
||||
} else {
|
||||
(c.nrows(), common_dim)
|
||||
};
|
||||
let b_shape = if trans_b {
|
||||
(c.ncols(), common_dim)
|
||||
} else {
|
||||
(common_dim, c.ncols())
|
||||
};
|
||||
let a = csr(value_strategy.clone(), a_shape.0, a_shape.1, max_nnz);
|
||||
let b = matrix(value_strategy.clone(), b_shape.0, b_shape.1);
|
||||
|
||||
|
@ -78,30 +89,36 @@ fn spmm_csr_dense_args_strategy() -> impl Strategy<Value=SpmmCsrDenseArgs<i32>>
|
|||
let beta = value_strategy.clone();
|
||||
|
||||
(Just(c), beta, alpha, Just(trans_a), a, Just(trans_b), b)
|
||||
}).prop_map(|(c, beta, alpha, trans_a, a, trans_b, b)| {
|
||||
SpmmCsrDenseArgs {
|
||||
})
|
||||
.prop_map(
|
||||
|(c, beta, alpha, trans_a, a, trans_b, b)| SpmmCsrDenseArgs {
|
||||
c,
|
||||
beta,
|
||||
alpha,
|
||||
a: if trans_a { Op::Transpose(a) } else { Op::NoOp(a) },
|
||||
b: if trans_b { Op::Transpose(b) } else { Op::NoOp(b) },
|
||||
}
|
||||
})
|
||||
a: if trans_a {
|
||||
Op::Transpose(a)
|
||||
} else {
|
||||
Op::NoOp(a)
|
||||
},
|
||||
b: if trans_b {
|
||||
Op::Transpose(b)
|
||||
} else {
|
||||
Op::NoOp(b)
|
||||
},
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
/// Returns matrices C, A and B with compatible dimensions such that it can be used
|
||||
/// in an `spmm` operation `C = beta * C + alpha * trans(A) * trans(B)`.
|
||||
fn spmm_csc_dense_args_strategy() -> impl Strategy<Value=SpmmCscDenseArgs<i32>> {
|
||||
spmm_csr_dense_args_strategy()
|
||||
.prop_map(|args| {
|
||||
SpmmCscDenseArgs {
|
||||
c: args.c,
|
||||
beta: args.beta,
|
||||
alpha: args.alpha,
|
||||
a: args.a.map_same_op(|a| CscMatrix::from(&a)),
|
||||
b: args.b
|
||||
}
|
||||
})
|
||||
fn spmm_csc_dense_args_strategy() -> impl Strategy<Value = SpmmCscDenseArgs<i32>> {
|
||||
spmm_csr_dense_args_strategy().prop_map(|args| SpmmCscDenseArgs {
|
||||
c: args.c,
|
||||
beta: args.beta,
|
||||
alpha: args.alpha,
|
||||
a: args.a.map_same_op(|a| CscMatrix::from(&a)),
|
||||
b: args.b,
|
||||
})
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
|
@ -120,7 +137,7 @@ struct SpaddCscArgs<T> {
|
|||
a: Op<CscMatrix<T>>,
|
||||
}
|
||||
|
||||
fn spadd_csr_prealloc_args_strategy() -> impl Strategy<Value=SpaddCsrArgs<i32>> {
|
||||
fn spadd_csr_prealloc_args_strategy() -> impl Strategy<Value = SpaddCsrArgs<i32>> {
|
||||
let value_strategy = PROPTEST_I32_VALUE_STRATEGY;
|
||||
|
||||
spadd_pattern_strategy()
|
||||
|
@ -131,66 +148,83 @@ fn spadd_csr_prealloc_args_strategy() -> impl Strategy<Value=SpaddCsrArgs<i32>>
|
|||
let c_values = vec![value_strategy.clone(); c_pattern.nnz()];
|
||||
let alpha = value_strategy.clone();
|
||||
let beta = value_strategy.clone();
|
||||
(Just(c_pattern), Just(a_pattern), c_values, a_values, alpha, beta, trans_strategy())
|
||||
}).prop_map(|(c_pattern, a_pattern, c_values, a_values, alpha, beta, trans_a)| {
|
||||
let c = CsrMatrix::try_from_pattern_and_values(c_pattern, c_values).unwrap();
|
||||
let a = CsrMatrix::try_from_pattern_and_values(a_pattern, a_values).unwrap();
|
||||
|
||||
let a = if trans_a { Op::Transpose(a.transpose()) } else { Op::NoOp(a) };
|
||||
SpaddCsrArgs { c, beta, alpha, a }
|
||||
(
|
||||
Just(c_pattern),
|
||||
Just(a_pattern),
|
||||
c_values,
|
||||
a_values,
|
||||
alpha,
|
||||
beta,
|
||||
trans_strategy(),
|
||||
)
|
||||
})
|
||||
.prop_map(
|
||||
|(c_pattern, a_pattern, c_values, a_values, alpha, beta, trans_a)| {
|
||||
let c = CsrMatrix::try_from_pattern_and_values(c_pattern, c_values).unwrap();
|
||||
let a = CsrMatrix::try_from_pattern_and_values(a_pattern, a_values).unwrap();
|
||||
|
||||
let a = if trans_a {
|
||||
Op::Transpose(a.transpose())
|
||||
} else {
|
||||
Op::NoOp(a)
|
||||
};
|
||||
SpaddCsrArgs { c, beta, alpha, a }
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
fn spadd_csc_prealloc_args_strategy() -> impl Strategy<Value=SpaddCscArgs<i32>> {
|
||||
spadd_csr_prealloc_args_strategy()
|
||||
.prop_map(|args| SpaddCscArgs {
|
||||
c: CscMatrix::from(&args.c),
|
||||
beta: args.beta,
|
||||
alpha: args.alpha,
|
||||
a: args.a.map_same_op(|a| CscMatrix::from(&a))
|
||||
})
|
||||
fn spadd_csc_prealloc_args_strategy() -> impl Strategy<Value = SpaddCscArgs<i32>> {
|
||||
spadd_csr_prealloc_args_strategy().prop_map(|args| SpaddCscArgs {
|
||||
c: CscMatrix::from(&args.c),
|
||||
beta: args.beta,
|
||||
alpha: args.alpha,
|
||||
a: args.a.map_same_op(|a| CscMatrix::from(&a)),
|
||||
})
|
||||
}
|
||||
|
||||
fn dense_strategy() -> impl Strategy<Value=DMatrix<i32>> {
|
||||
matrix(PROPTEST_I32_VALUE_STRATEGY, PROPTEST_MATRIX_DIM, PROPTEST_MATRIX_DIM)
|
||||
fn dense_strategy() -> impl Strategy<Value = DMatrix<i32>> {
|
||||
matrix(
|
||||
PROPTEST_I32_VALUE_STRATEGY,
|
||||
PROPTEST_MATRIX_DIM,
|
||||
PROPTEST_MATRIX_DIM,
|
||||
)
|
||||
}
|
||||
|
||||
fn trans_strategy() -> impl Strategy<Value=bool> + Clone {
|
||||
fn trans_strategy() -> impl Strategy<Value = bool> + Clone {
|
||||
proptest::bool::ANY
|
||||
}
|
||||
|
||||
/// Wraps the values of the given strategy in `Op`, producing both transposed and non-transposed
|
||||
/// values.
|
||||
fn op_strategy<S: Strategy>(strategy: S) -> impl Strategy<Value=Op<S::Value>> {
|
||||
fn op_strategy<S: Strategy>(strategy: S) -> impl Strategy<Value = Op<S::Value>> {
|
||||
let is_transposed = proptest::bool::ANY;
|
||||
(strategy, is_transposed)
|
||||
.prop_map(|(obj, is_trans)| if is_trans {
|
||||
(strategy, is_transposed).prop_map(|(obj, is_trans)| {
|
||||
if is_trans {
|
||||
Op::Transpose(obj)
|
||||
} else {
|
||||
Op::NoOp(obj)
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn pattern_strategy() -> impl Strategy<Value=SparsityPattern> {
|
||||
fn pattern_strategy() -> impl Strategy<Value = SparsityPattern> {
|
||||
sparsity_pattern(PROPTEST_MATRIX_DIM, PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ)
|
||||
}
|
||||
|
||||
/// Constructs pairs (a, b) where a and b have the same dimensions
|
||||
fn spadd_pattern_strategy() -> impl Strategy<Value=(SparsityPattern, SparsityPattern)> {
|
||||
pattern_strategy()
|
||||
.prop_flat_map(|a| {
|
||||
let b = sparsity_pattern(a.major_dim(), a.minor_dim(), PROPTEST_MAX_NNZ);
|
||||
(Just(a), b)
|
||||
})
|
||||
fn spadd_pattern_strategy() -> impl Strategy<Value = (SparsityPattern, SparsityPattern)> {
|
||||
pattern_strategy().prop_flat_map(|a| {
|
||||
let b = sparsity_pattern(a.major_dim(), a.minor_dim(), PROPTEST_MAX_NNZ);
|
||||
(Just(a), b)
|
||||
})
|
||||
}
|
||||
|
||||
/// Constructs pairs (a, b) where a and b have compatible dimensions for a matrix product
|
||||
fn spmm_csr_pattern_strategy() -> impl Strategy<Value=(SparsityPattern, SparsityPattern)> {
|
||||
pattern_strategy()
|
||||
.prop_flat_map(|a| {
|
||||
let b = sparsity_pattern(a.minor_dim(), PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ);
|
||||
(Just(a), b)
|
||||
})
|
||||
fn spmm_csr_pattern_strategy() -> impl Strategy<Value = (SparsityPattern, SparsityPattern)> {
|
||||
pattern_strategy().prop_flat_map(|a| {
|
||||
let b = sparsity_pattern(a.minor_dim(), PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ);
|
||||
(Just(a), b)
|
||||
})
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
|
@ -211,86 +245,98 @@ struct SpmmCscArgs<T> {
|
|||
b: Op<CscMatrix<T>>,
|
||||
}
|
||||
|
||||
fn spmm_csr_prealloc_args_strategy() -> impl Strategy<Value=SpmmCsrArgs<i32>> {
|
||||
fn spmm_csr_prealloc_args_strategy() -> impl Strategy<Value = SpmmCsrArgs<i32>> {
|
||||
spmm_csr_pattern_strategy()
|
||||
.prop_flat_map(|(a_pattern, b_pattern)| {
|
||||
let a_values = vec![PROPTEST_I32_VALUE_STRATEGY; a_pattern.nnz()];
|
||||
let b_values = vec![PROPTEST_I32_VALUE_STRATEGY; b_pattern.nnz()];
|
||||
let c_pattern = spmm_csr_pattern(&a_pattern, &b_pattern);
|
||||
let c_values = vec![PROPTEST_I32_VALUE_STRATEGY; c_pattern.nnz()];
|
||||
let a = a_values.prop_map(move |values|
|
||||
CsrMatrix::try_from_pattern_and_values(a_pattern.clone(), values).unwrap());
|
||||
let b = b_values.prop_map(move |values|
|
||||
CsrMatrix::try_from_pattern_and_values(b_pattern.clone(), values).unwrap());
|
||||
let c = c_values.prop_map(move |values|
|
||||
CsrMatrix::try_from_pattern_and_values(c_pattern.clone(), values).unwrap());
|
||||
let a = a_values.prop_map(move |values| {
|
||||
CsrMatrix::try_from_pattern_and_values(a_pattern.clone(), values).unwrap()
|
||||
});
|
||||
let b = b_values.prop_map(move |values| {
|
||||
CsrMatrix::try_from_pattern_and_values(b_pattern.clone(), values).unwrap()
|
||||
});
|
||||
let c = c_values.prop_map(move |values| {
|
||||
CsrMatrix::try_from_pattern_and_values(c_pattern.clone(), values).unwrap()
|
||||
});
|
||||
let alpha = PROPTEST_I32_VALUE_STRATEGY;
|
||||
let beta = PROPTEST_I32_VALUE_STRATEGY;
|
||||
(c, beta, alpha, trans_strategy(), a, trans_strategy(), b)
|
||||
})
|
||||
.prop_map(|(c, beta, alpha, trans_a, a, trans_b, b)| {
|
||||
SpmmCsrArgs::<i32> {
|
||||
.prop_map(
|
||||
|(c, beta, alpha, trans_a, a, trans_b, b)| SpmmCsrArgs::<i32> {
|
||||
c,
|
||||
beta,
|
||||
alpha,
|
||||
a: if trans_a { Op::Transpose(a.transpose()) } else { Op::NoOp(a) },
|
||||
b: if trans_b { Op::Transpose(b.transpose()) } else { Op::NoOp(b) }
|
||||
}
|
||||
})
|
||||
a: if trans_a {
|
||||
Op::Transpose(a.transpose())
|
||||
} else {
|
||||
Op::NoOp(a)
|
||||
},
|
||||
b: if trans_b {
|
||||
Op::Transpose(b.transpose())
|
||||
} else {
|
||||
Op::NoOp(b)
|
||||
},
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
fn spmm_csc_prealloc_args_strategy() -> impl Strategy<Value=SpmmCscArgs<i32>> {
|
||||
fn spmm_csc_prealloc_args_strategy() -> impl Strategy<Value = SpmmCscArgs<i32>> {
|
||||
// Note: Converting from CSR is simple, but might be significantly slower than
|
||||
// writing a common implementation that can be shared between CSR and CSC args
|
||||
spmm_csr_prealloc_args_strategy()
|
||||
.prop_map(|args| {
|
||||
SpmmCscArgs {
|
||||
c: CscMatrix::from(&args.c),
|
||||
beta: args.beta,
|
||||
alpha: args.alpha,
|
||||
a: args.a.map_same_op(|a| CscMatrix::from(&a)),
|
||||
b: args.b.map_same_op(|b| CscMatrix::from(&b))
|
||||
spmm_csr_prealloc_args_strategy().prop_map(|args| SpmmCscArgs {
|
||||
c: CscMatrix::from(&args.c),
|
||||
beta: args.beta,
|
||||
alpha: args.alpha,
|
||||
a: args.a.map_same_op(|a| CscMatrix::from(&a)),
|
||||
b: args.b.map_same_op(|b| CscMatrix::from(&b)),
|
||||
})
|
||||
}
|
||||
|
||||
fn csc_invertible_diagonal() -> impl Strategy<Value = CscMatrix<f64>> {
|
||||
let non_zero_values =
|
||||
value_strategy::<f64>().prop_filter("Only non-zeros values accepted", |x| x != &0.0);
|
||||
|
||||
vector(non_zero_values, PROPTEST_MATRIX_DIM).prop_map(|d| {
|
||||
let mut matrix = CscMatrix::identity(d.len());
|
||||
matrix.values_mut().clone_from_slice(&d.as_slice());
|
||||
matrix
|
||||
})
|
||||
}
|
||||
|
||||
fn csc_square_with_non_zero_diagonals() -> impl Strategy<Value = CscMatrix<f64>> {
|
||||
csc_invertible_diagonal().prop_flat_map(|d| {
|
||||
csc(
|
||||
value_strategy::<f64>(),
|
||||
d.nrows(),
|
||||
d.nrows(),
|
||||
PROPTEST_MAX_NNZ,
|
||||
)
|
||||
.prop_map(move |mut c| {
|
||||
for (i, j, v) in c.triplet_iter_mut() {
|
||||
if i == j {
|
||||
*v = 0.0;
|
||||
}
|
||||
}
|
||||
|
||||
// Return the sum of a matrix with zero diagonals and an invertible diagonal
|
||||
// matrix
|
||||
c + &d
|
||||
})
|
||||
}
|
||||
|
||||
fn csc_invertible_diagonal() -> impl Strategy<Value=CscMatrix<f64>> {
|
||||
let non_zero_values = value_strategy::<f64>()
|
||||
.prop_filter("Only non-zeros values accepted", |x| x != &0.0);
|
||||
|
||||
vector(non_zero_values, PROPTEST_MATRIX_DIM)
|
||||
.prop_map(|d| {
|
||||
let mut matrix = CscMatrix::identity(d.len());
|
||||
matrix.values_mut().clone_from_slice(&d.as_slice());
|
||||
matrix
|
||||
})
|
||||
}
|
||||
|
||||
fn csc_square_with_non_zero_diagonals() -> impl Strategy<Value=CscMatrix<f64>> {
|
||||
csc_invertible_diagonal()
|
||||
.prop_flat_map(|d| {
|
||||
csc(value_strategy::<f64>(), d.nrows(), d.nrows(), PROPTEST_MAX_NNZ)
|
||||
.prop_map(move |mut c| {
|
||||
for (i, j, v) in c.triplet_iter_mut() {
|
||||
if i == j {
|
||||
*v = 0.0;
|
||||
}
|
||||
}
|
||||
|
||||
// Return the sum of a matrix with zero diagonals and an invertible diagonal
|
||||
// matrix
|
||||
c + &d
|
||||
})
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
/// Helper function to help us call dense GEMM with our `Op` type
|
||||
fn dense_gemm<'a>(beta: i32,
|
||||
c: impl Into<DMatrixSliceMut<'a, i32>>,
|
||||
alpha: i32,
|
||||
a: Op<impl Into<DMatrixSlice<'a, i32>>>,
|
||||
b: Op<impl Into<DMatrixSlice<'a, i32>>>)
|
||||
{
|
||||
fn dense_gemm<'a>(
|
||||
beta: i32,
|
||||
c: impl Into<DMatrixSliceMut<'a, i32>>,
|
||||
alpha: i32,
|
||||
a: Op<impl Into<DMatrixSlice<'a, i32>>>,
|
||||
b: Op<impl Into<DMatrixSlice<'a, i32>>>,
|
||||
) {
|
||||
let mut c = c.into();
|
||||
let a = a.convert();
|
||||
let b = b.convert();
|
||||
|
@ -300,7 +346,7 @@ fn dense_gemm<'a>(beta: i32,
|
|||
(NoOp(a), NoOp(b)) => c.gemm(alpha, &a, &b, beta),
|
||||
(Transpose(a), NoOp(b)) => c.gemm(alpha, &a.transpose(), &b, beta),
|
||||
(NoOp(a), Transpose(b)) => c.gemm(alpha, &a, &b.transpose(), beta),
|
||||
(Transpose(a), Transpose(b)) => c.gemm(alpha, &a.transpose(), &b.transpose(), beta)
|
||||
(Transpose(a), Transpose(b)) => c.gemm(alpha, &a.transpose(), &b.transpose(), beta),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -7,11 +7,9 @@ fn sparsity_pattern_valid_data() {
|
|||
|
||||
{
|
||||
// A pattern with zero explicitly stored entries
|
||||
let pattern = SparsityPattern::try_from_offsets_and_indices(3,
|
||||
2,
|
||||
vec![0, 0, 0, 0],
|
||||
Vec::new())
|
||||
.unwrap();
|
||||
let pattern =
|
||||
SparsityPattern::try_from_offsets_and_indices(3, 2, vec![0, 0, 0, 0], Vec::new())
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(pattern.major_dim(), 3);
|
||||
assert_eq!(pattern.minor_dim(), 2);
|
||||
|
@ -36,7 +34,7 @@ fn sparsity_pattern_valid_data() {
|
|||
let indices = vec![0, 5, 1, 2, 3];
|
||||
let pattern =
|
||||
SparsityPattern::try_from_offsets_and_indices(3, 6, offsets.clone(), indices.clone())
|
||||
.unwrap();
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(pattern.major_dim(), 3);
|
||||
assert_eq!(pattern.minor_dim(), 6);
|
||||
|
@ -46,8 +44,10 @@ fn sparsity_pattern_valid_data() {
|
|||
assert_eq!(pattern.lane(0), &[0, 5]);
|
||||
assert_eq!(pattern.lane(1), &[]);
|
||||
assert_eq!(pattern.lane(2), &[1, 2, 3]);
|
||||
assert_eq!(pattern.entries().collect::<Vec<_>>(),
|
||||
vec![(0, 0), (0, 5), (2, 1), (2, 2), (2, 3)]);
|
||||
assert_eq!(
|
||||
pattern.entries().collect::<Vec<_>>(),
|
||||
vec![(0, 0), (0, 5), (2, 1), (2, 2), (2, 3)]
|
||||
);
|
||||
|
||||
let (offsets2, indices2) = pattern.disassemble();
|
||||
assert_eq!(offsets2, offsets);
|
||||
|
@ -60,7 +60,10 @@ fn sparsity_pattern_try_from_invalid_data() {
|
|||
{
|
||||
// Empty offset array (invalid length)
|
||||
let pattern = SparsityPattern::try_from_offsets_and_indices(0, 0, Vec::new(), Vec::new());
|
||||
assert_eq!(pattern, Err(SparsityPatternFormatError::InvalidOffsetArrayLength));
|
||||
assert_eq!(
|
||||
pattern,
|
||||
Err(SparsityPatternFormatError::InvalidOffsetArrayLength)
|
||||
);
|
||||
}
|
||||
|
||||
{
|
||||
|
@ -69,7 +72,10 @@ fn sparsity_pattern_try_from_invalid_data() {
|
|||
let indices = vec![0, 1, 2, 3, 5];
|
||||
|
||||
let pattern = SparsityPattern::try_from_offsets_and_indices(3, 6, offsets, indices);
|
||||
assert!(matches!(pattern, Err(SparsityPatternFormatError::InvalidOffsetArrayLength)));
|
||||
assert!(matches!(
|
||||
pattern,
|
||||
Err(SparsityPatternFormatError::InvalidOffsetArrayLength)
|
||||
));
|
||||
}
|
||||
|
||||
{
|
||||
|
@ -77,7 +83,10 @@ fn sparsity_pattern_try_from_invalid_data() {
|
|||
let offsets = vec![1, 2, 2, 5];
|
||||
let indices = vec![0, 5, 1, 2, 3];
|
||||
let pattern = SparsityPattern::try_from_offsets_and_indices(3, 6, offsets, indices);
|
||||
assert!(matches!(pattern, Err(SparsityPatternFormatError::InvalidOffsetFirstLast)));
|
||||
assert!(matches!(
|
||||
pattern,
|
||||
Err(SparsityPatternFormatError::InvalidOffsetFirstLast)
|
||||
));
|
||||
}
|
||||
|
||||
{
|
||||
|
@ -85,7 +94,10 @@ fn sparsity_pattern_try_from_invalid_data() {
|
|||
let offsets = vec![0, 2, 2, 4];
|
||||
let indices = vec![0, 5, 1, 2, 3];
|
||||
let pattern = SparsityPattern::try_from_offsets_and_indices(3, 6, offsets, indices);
|
||||
assert!(matches!(pattern, Err(SparsityPatternFormatError::InvalidOffsetFirstLast)));
|
||||
assert!(matches!(
|
||||
pattern,
|
||||
Err(SparsityPatternFormatError::InvalidOffsetFirstLast)
|
||||
));
|
||||
}
|
||||
|
||||
{
|
||||
|
@ -93,7 +105,10 @@ fn sparsity_pattern_try_from_invalid_data() {
|
|||
let offsets = vec![0, 2, 2];
|
||||
let indices = vec![0, 5, 1, 2, 3];
|
||||
let pattern = SparsityPattern::try_from_offsets_and_indices(3, 6, offsets, indices);
|
||||
assert!(matches!(pattern, Err(SparsityPatternFormatError::InvalidOffsetArrayLength)));
|
||||
assert!(matches!(
|
||||
pattern,
|
||||
Err(SparsityPatternFormatError::InvalidOffsetArrayLength)
|
||||
));
|
||||
}
|
||||
|
||||
{
|
||||
|
@ -101,7 +116,10 @@ fn sparsity_pattern_try_from_invalid_data() {
|
|||
let offsets = vec![0, 3, 2, 5];
|
||||
let indices = vec![0, 1, 2, 3, 4];
|
||||
let pattern = SparsityPattern::try_from_offsets_and_indices(3, 6, offsets, indices);
|
||||
assert_eq!(pattern, Err(SparsityPatternFormatError::NonmonotonicOffsets));
|
||||
assert_eq!(
|
||||
pattern,
|
||||
Err(SparsityPatternFormatError::NonmonotonicOffsets)
|
||||
);
|
||||
}
|
||||
|
||||
{
|
||||
|
@ -109,7 +127,10 @@ fn sparsity_pattern_try_from_invalid_data() {
|
|||
let offsets = vec![0, 2, 2, 5];
|
||||
let indices = vec![0, 2, 3, 1, 4];
|
||||
let pattern = SparsityPattern::try_from_offsets_and_indices(3, 6, offsets, indices);
|
||||
assert_eq!(pattern, Err(SparsityPatternFormatError::NonmonotonicMinorIndices));
|
||||
assert_eq!(
|
||||
pattern,
|
||||
Err(SparsityPatternFormatError::NonmonotonicMinorIndices)
|
||||
);
|
||||
}
|
||||
|
||||
{
|
||||
|
@ -117,7 +138,10 @@ fn sparsity_pattern_try_from_invalid_data() {
|
|||
let offsets = vec![0, 2, 2, 5];
|
||||
let indices = vec![0, 6, 1, 2, 3];
|
||||
let pattern = SparsityPattern::try_from_offsets_and_indices(3, 6, offsets, indices);
|
||||
assert_eq!(pattern, Err(SparsityPatternFormatError::MinorIndexOutOfBounds));
|
||||
assert_eq!(
|
||||
pattern,
|
||||
Err(SparsityPatternFormatError::MinorIndexOutOfBounds)
|
||||
);
|
||||
}
|
||||
|
||||
{
|
||||
|
|
|
@ -6,25 +6,27 @@ fn coo_no_duplicates_generates_admissible_matrices() {
|
|||
|
||||
#[cfg(feature = "slow-tests")]
|
||||
mod slow {
|
||||
use nalgebra_sparse::proptest::{coo_with_duplicates, coo_no_duplicates, csr, csc, sparsity_pattern};
|
||||
use nalgebra::DMatrix;
|
||||
use nalgebra_sparse::proptest::{
|
||||
coo_no_duplicates, coo_with_duplicates, csc, csr, sparsity_pattern,
|
||||
};
|
||||
|
||||
use proptest::test_runner::TestRunner;
|
||||
use proptest::strategy::ValueTree;
|
||||
use itertools::Itertools;
|
||||
use proptest::strategy::ValueTree;
|
||||
use proptest::test_runner::TestRunner;
|
||||
|
||||
use proptest::prelude::*;
|
||||
|
||||
use nalgebra_sparse::csr::CsrMatrix;
|
||||
use std::collections::HashSet;
|
||||
use std::iter::repeat;
|
||||
use std::ops::RangeInclusive;
|
||||
use nalgebra_sparse::csr::CsrMatrix;
|
||||
|
||||
fn generate_all_possible_matrices(value_range: RangeInclusive<i32>,
|
||||
rows_range: RangeInclusive<usize>,
|
||||
cols_range: RangeInclusive<usize>)
|
||||
-> HashSet<DMatrix<i32>>
|
||||
{
|
||||
fn generate_all_possible_matrices(
|
||||
value_range: RangeInclusive<i32>,
|
||||
rows_range: RangeInclusive<usize>,
|
||||
cols_range: RangeInclusive<usize>,
|
||||
) -> HashSet<DMatrix<i32>> {
|
||||
// Enumerate all possible combinations
|
||||
let mut all_combinations = HashSet::new();
|
||||
for nrows in rows_range {
|
||||
|
@ -48,7 +50,11 @@ mod slow {
|
|||
.take(n_values)
|
||||
.multi_cartesian_product();
|
||||
for matrix_values in values_iter {
|
||||
all_combinations.insert(DMatrix::from_row_slice(nrows, ncols, &matrix_values));
|
||||
all_combinations.insert(DMatrix::from_row_slice(
|
||||
nrows,
|
||||
ncols,
|
||||
&matrix_values,
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -80,12 +86,14 @@ mod slow {
|
|||
// Enumerate all possible combinations
|
||||
let all_combinations = generate_all_possible_matrices(values, rows, cols);
|
||||
|
||||
let visited_combinations = sample_matrix_output_space(strategy,
|
||||
&mut runner,
|
||||
num_generated_matrices);
|
||||
let visited_combinations =
|
||||
sample_matrix_output_space(strategy, &mut runner, num_generated_matrices);
|
||||
|
||||
assert_eq!(visited_combinations.len(), all_combinations.len());
|
||||
assert_eq!(visited_combinations, all_combinations, "Did not sample all possible values.");
|
||||
assert_eq!(
|
||||
visited_combinations, all_combinations,
|
||||
"Did not sample all possible values."
|
||||
);
|
||||
}
|
||||
|
||||
#[cfg(feature = "slow-tests")]
|
||||
|
@ -113,9 +121,8 @@ mod slow {
|
|||
// `coo_with_duplicates`)
|
||||
let all_combinations = generate_all_possible_matrices(values, rows, cols);
|
||||
|
||||
let visited_combinations = sample_matrix_output_space(strategy,
|
||||
&mut runner,
|
||||
num_generated_matrices);
|
||||
let visited_combinations =
|
||||
sample_matrix_output_space(strategy, &mut runner, num_generated_matrices);
|
||||
|
||||
// Here we cannot verify that the set of visited combinations is *equal* to
|
||||
// all possible outcomes with the given constraints, however the
|
||||
|
@ -143,12 +150,14 @@ mod slow {
|
|||
|
||||
let all_combinations = generate_all_possible_matrices(values, rows, cols);
|
||||
|
||||
let visited_combinations = sample_matrix_output_space(strategy,
|
||||
&mut runner,
|
||||
num_generated_matrices);
|
||||
let visited_combinations =
|
||||
sample_matrix_output_space(strategy, &mut runner, num_generated_matrices);
|
||||
|
||||
assert_eq!(visited_combinations.len(), all_combinations.len());
|
||||
assert_eq!(visited_combinations, all_combinations, "Did not sample all possible values.");
|
||||
assert_eq!(
|
||||
visited_combinations, all_combinations,
|
||||
"Did not sample all possible values."
|
||||
);
|
||||
}
|
||||
|
||||
#[cfg(feature = "slow-tests")]
|
||||
|
@ -169,12 +178,14 @@ mod slow {
|
|||
|
||||
let all_combinations = generate_all_possible_matrices(values, rows, cols);
|
||||
|
||||
let visited_combinations = sample_matrix_output_space(strategy,
|
||||
&mut runner,
|
||||
num_generated_matrices);
|
||||
let visited_combinations =
|
||||
sample_matrix_output_space(strategy, &mut runner, num_generated_matrices);
|
||||
|
||||
assert_eq!(visited_combinations.len(), all_combinations.len());
|
||||
assert_eq!(visited_combinations, all_combinations, "Did not sample all possible values.");
|
||||
assert_eq!(
|
||||
visited_combinations, all_combinations,
|
||||
"Did not sample all possible values."
|
||||
);
|
||||
}
|
||||
|
||||
#[cfg(feature = "slow-tests")]
|
||||
|
@ -206,13 +217,14 @@ mod slow {
|
|||
assert_eq!(visited_patterns, all_possible_patterns);
|
||||
}
|
||||
|
||||
fn sample_matrix_output_space<S>(strategy: S,
|
||||
runner: &mut TestRunner,
|
||||
num_samples: usize)
|
||||
-> HashSet<DMatrix<i32>>
|
||||
fn sample_matrix_output_space<S>(
|
||||
strategy: S,
|
||||
runner: &mut TestRunner,
|
||||
num_samples: usize,
|
||||
) -> HashSet<DMatrix<i32>>
|
||||
where
|
||||
S: Strategy,
|
||||
DMatrix<i32>: for<'b> From<&'b S::Value>
|
||||
DMatrix<i32>: for<'b> From<&'b S::Value>,
|
||||
{
|
||||
sample_strategy(strategy, runner)
|
||||
.take(num_samples)
|
||||
|
@ -220,8 +232,10 @@ mod slow {
|
|||
.collect()
|
||||
}
|
||||
|
||||
fn sample_strategy<'a, S: 'a + Strategy>(strategy: S, runner: &'a mut TestRunner)
|
||||
-> impl 'a + Iterator<Item=S::Value> {
|
||||
fn sample_strategy<'a, S: 'a + Strategy>(
|
||||
strategy: S,
|
||||
runner: &'a mut TestRunner,
|
||||
) -> impl 'a + Iterator<Item = S::Value> {
|
||||
repeat(()).map(move |_| {
|
||||
let tree = strategy
|
||||
.new_tree(runner)
|
||||
|
|
Loading…
Reference in New Issue