Refactor most of Csr/CscMatrix logic into helper type CsMatrix

Still need to update CSC API so that it mirrors CsrMatrix
in terms of get_entry and so on.
This commit is contained in:
Andreas Longva 2020-12-22 10:19:17 +01:00
parent 8983027b39
commit b59c4a3216
4 changed files with 394 additions and 261 deletions

286
nalgebra-sparse/src/cs.rs Normal file
View File

@ -0,0 +1,286 @@
use crate::pattern::SparsityPattern;
use crate::{SparseEntry, SparseEntryMut};
use std::sync::Arc;
use std::ops::Range;
use std::ptr::slice_from_raw_parts_mut;
/// An abstract compressed matrix.
///
/// For the time being, this is only used internally to share implementation between
/// CSR and CSC matrices.
///
/// A CSR matrix is obtained by associating rows with the major dimension, while a CSC matrix
/// is obtained by associating columns with the major dimension.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CsMatrix<T> {
sparsity_pattern: Arc<SparsityPattern>,
values: Vec<T>
}
impl<T> CsMatrix<T> {
/// Create a zero matrix with no explicitly stored entries.
#[inline]
pub fn new(major_dim: usize, minor_dim: usize) -> Self {
Self {
sparsity_pattern: Arc::new(SparsityPattern::new(major_dim, minor_dim)),
values: vec![],
}
}
#[inline]
pub fn pattern(&self) -> &Arc<SparsityPattern> {
&self.sparsity_pattern
}
#[inline]
pub fn values(&self) -> &[T] {
&self.values
}
#[inline]
pub fn values_mut(&mut self) -> &mut [T] {
&mut self.values
}
/// Returns the raw data represented as a tuple `(major_offsets, minor_indices, values)`.
#[inline]
pub fn cs_data(&self) -> (&[usize], &[usize], &[T]) {
let pattern = self.pattern().as_ref();
(pattern.major_offsets(), pattern.minor_indices(), &self.values)
}
/// Returns the raw data represented as a tuple `(major_offsets, minor_indices, values)`.
#[inline]
pub fn cs_data_mut(&mut self) -> (&[usize], &[usize], &mut [T]) {
let pattern = self.sparsity_pattern.as_ref();
(pattern.major_offsets(), pattern.minor_indices(), &mut self.values)
}
#[inline]
pub fn pattern_and_values_mut(&mut self) -> (&Arc<SparsityPattern>, &mut [T]) {
(&self.sparsity_pattern, &mut self.values)
}
#[inline]
pub fn from_pattern_and_values(pattern: Arc<SparsityPattern>, values: Vec<T>)
-> Self {
assert_eq!(pattern.nnz(), values.len(), "Internal error: consumers should verify shape compatibility.");
Self {
sparsity_pattern: pattern,
values,
}
}
/// Internal method for simplifying access to a lane's data
#[inline]
pub fn get_index_range(&self, row_index: usize) -> Option<Range<usize>> {
let row_begin = *self.sparsity_pattern.major_offsets().get(row_index)?;
let row_end = *self.sparsity_pattern.major_offsets().get(row_index + 1)?;
Some(row_begin .. row_end)
}
pub fn take_pattern_and_values(self) -> (Arc<SparsityPattern>, Vec<T>) {
(self.sparsity_pattern, self.values)
}
#[inline]
pub fn disassemble(self) -> (Vec<usize>, Vec<usize>, Vec<T>) {
// Take an Arc to the pattern, which might be the sole reference to the data after
// taking the values. This is important, because it might let us avoid cloning the data
// further below.
let pattern = self.sparsity_pattern;
let values = self.values;
// Try to take the pattern out of the `Arc` if possible,
// otherwise clone the pattern.
let owned_pattern = Arc::try_unwrap(pattern)
.unwrap_or_else(|arc| SparsityPattern::clone(&*arc));
let (offsets, indices) = owned_pattern.disassemble();
(offsets, indices, values)
}
/// Returns an entry for the given major/minor indices, or `None` if the indices are out
/// of bounds.
pub fn get_entry(&self, major_index: usize, minor_index: usize) -> Option<SparseEntry<T>> {
let row_range = self.get_index_range(major_index)?;
let (_, minor_indices, values) = self.cs_data();
let minor_indices = &minor_indices[row_range.clone()];
let values = &values[row_range];
get_entry_from_slices(self.pattern().minor_dim(), minor_indices, values, minor_index)
}
/// Returns a mutable entry for the given major/minor indices, or `None` if the indices are out
/// of bounds.
pub fn get_entry_mut(&mut self, major_index: usize, minor_index: usize)
-> Option<SparseEntryMut<T>> {
let row_range = self.get_index_range(major_index)?;
let minor_dim = self.pattern().minor_dim();
let (_, minor_indices, values) = self.cs_data_mut();
let minor_indices = &minor_indices[row_range.clone()];
let values = &mut values[row_range];
get_mut_entry_from_slices(minor_dim, minor_indices, values, minor_index)
}
pub fn get_lane(&self, index: usize) -> Option<CsLane<T>> {
let range = self.get_index_range(index)?;
let (_, minor_indices, values) = self.cs_data();
Some(CsLane {
minor_indices: &minor_indices[range.clone()],
values: &values[range],
minor_dim: self.pattern().minor_dim()
})
}
#[inline]
pub fn get_lane_mut(&mut self, index: usize) -> Option<CsLaneMut<T>> {
let range = self.get_index_range(index)?;
let minor_dim = self.pattern().minor_dim();
let (_, minor_indices, values) = self.cs_data_mut();
Some(CsLaneMut {
minor_dim,
minor_indices: &minor_indices[range.clone()],
values: &mut values[range]
})
}
}
pub fn get_entry_from_slices<'a, T>(
minor_dim: usize,
minor_indices: &'a [usize],
values: &'a [T],
global_minor_index: usize) -> Option<SparseEntry<'a, T>> {
let local_index = minor_indices.binary_search(&global_minor_index);
if let Ok(local_index) = local_index {
Some(SparseEntry::NonZero(&values[local_index]))
} else if global_minor_index < minor_dim {
Some(SparseEntry::Zero)
} else {
None
}
}
pub fn get_mut_entry_from_slices<'a, T>(
minor_dim: usize,
minor_indices: &'a [usize],
values: &'a mut [T],
global_minor_indices: usize) -> Option<SparseEntryMut<'a, T>> {
let local_index = minor_indices.binary_search(&global_minor_indices);
if let Ok(local_index) = local_index {
Some(SparseEntryMut::NonZero(&mut values[local_index]))
} else if global_minor_indices < minor_dim {
Some(SparseEntryMut::Zero)
} else {
None
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CsLane<'a, T> {
pub minor_dim: usize,
pub minor_indices: &'a [usize],
pub values: &'a [T]
}
#[derive(Debug, PartialEq, Eq)]
pub struct CsLaneMut<'a, T> {
pub minor_dim: usize,
pub minor_indices: &'a [usize],
pub values: &'a mut [T]
}
pub struct CsLaneIter<'a, T> {
// The index of the lane that will be returned on the next iteration
current_lane_idx: usize,
pattern: &'a SparsityPattern,
remaining_values: &'a [T],
}
impl<'a, T> CsLaneIter<'a, T> {
pub fn new(pattern: &'a SparsityPattern, values: &'a [T]) -> Self {
Self {
current_lane_idx: 0,
pattern,
remaining_values: values
}
}
}
impl<'a, T> Iterator for CsLaneIter<'a, T>
where
T: 'a
{
type Item = CsLane<'a, T>;
fn next(&mut self) -> Option<Self::Item> {
let lane = self.pattern.get_lane(self.current_lane_idx);
let minor_dim = self.pattern.minor_dim();
if let Some(minor_indices) = lane {
let count = minor_indices.len();
let values_in_lane = &self.remaining_values[..count];
self.remaining_values = &self.remaining_values[count ..];
self.current_lane_idx += 1;
Some(CsLane {
minor_dim,
minor_indices,
values: values_in_lane
})
} else {
None
}
}
}
pub struct CsLaneIterMut<'a, T> {
// The index of the lane that will be returned on the next iteration
current_lane_idx: usize,
pattern: &'a SparsityPattern,
remaining_values: *mut T,
}
impl<'a, T> CsLaneIterMut<'a, T> {
pub fn new(pattern: &'a SparsityPattern, values: &'a mut [T]) -> Self {
Self {
current_lane_idx: 0,
pattern,
remaining_values: values.as_mut_ptr()
}
}
}
impl<'a, T> Iterator for CsLaneIterMut<'a, T>
where
T: 'a
{
type Item = CsLaneMut<'a, T>;
fn next(&mut self) -> Option<Self::Item> {
let lane = self.pattern.get_lane(self.current_lane_idx);
let minor_dim = self.pattern.minor_dim();
if let Some(minor_indices) = lane {
let count = minor_indices.len();
// Note: I can't think of any way to construct this iterator without unsafe.
let values_in_lane;
unsafe {
values_in_lane = &mut *slice_from_raw_parts_mut(self.remaining_values, count);
self.remaining_values = self.remaining_values.add(count);
}
self.current_lane_idx += 1;
Some(CsLaneMut {
minor_dim,
minor_indices,
values: values_in_lane
})
} else {
None
}
}
}

View File

@ -2,13 +2,12 @@
use crate::{SparseFormatError, SparseFormatErrorKind}; use crate::{SparseFormatError, SparseFormatErrorKind};
use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter}; use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter};
use crate::csr::CsrMatrix;
use crate::cs::{CsMatrix, CsLane, CsLaneMut, CsLaneIter, CsLaneIterMut};
use std::sync::Arc; use std::sync::Arc;
use std::slice::{IterMut, Iter}; use std::slice::{IterMut, Iter};
use std::ops::Range;
use num_traits::Zero; use num_traits::Zero;
use std::ptr::slice_from_raw_parts_mut;
use crate::csr::CsrMatrix;
use nalgebra::Scalar; use nalgebra::Scalar;
/// A CSC representation of a sparse matrix. /// A CSC representation of a sparse matrix.
@ -21,29 +20,27 @@ use nalgebra::Scalar;
#[derive(Debug, Clone, PartialEq, Eq)] #[derive(Debug, Clone, PartialEq, Eq)]
pub struct CscMatrix<T> { pub struct CscMatrix<T> {
// Cols are major, rows are minor in the sparsity pattern // Cols are major, rows are minor in the sparsity pattern
sparsity_pattern: Arc<SparsityPattern>, cs: CsMatrix<T>,
values: Vec<T>,
} }
impl<T> CscMatrix<T> { impl<T> CscMatrix<T> {
/// Create a zero CSC matrix with no explicitly stored entries. /// Create a zero CSC matrix with no explicitly stored entries.
pub fn new(nrows: usize, ncols: usize) -> Self { pub fn new(nrows: usize, ncols: usize) -> Self {
Self { Self {
sparsity_pattern: Arc::new(SparsityPattern::new(ncols, nrows)), cs: CsMatrix::new(ncols, nrows)
values: vec![],
} }
} }
/// The number of rows in the matrix. /// The number of rows in the matrix.
#[inline] #[inline]
pub fn nrows(&self) -> usize { pub fn nrows(&self) -> usize {
self.sparsity_pattern.minor_dim() self.cs.pattern().minor_dim()
} }
/// The number of columns in the matrix. /// The number of columns in the matrix.
#[inline] #[inline]
pub fn ncols(&self) -> usize { pub fn ncols(&self) -> usize {
self.sparsity_pattern.major_dim() self.cs.pattern().major_dim()
} }
/// The number of non-zeros in the matrix. /// The number of non-zeros in the matrix.
@ -53,31 +50,31 @@ impl<T> CscMatrix<T> {
/// be zero. Corresponds to the number of entries in the sparsity pattern. /// be zero. Corresponds to the number of entries in the sparsity pattern.
#[inline] #[inline]
pub fn nnz(&self) -> usize { pub fn nnz(&self) -> usize {
self.sparsity_pattern.nnz() self.pattern().nnz()
} }
/// The column offsets defining part of the CSC format. /// The column offsets defining part of the CSC format.
#[inline] #[inline]
pub fn col_offsets(&self) -> &[usize] { pub fn col_offsets(&self) -> &[usize] {
self.sparsity_pattern.major_offsets() self.pattern().major_offsets()
} }
/// The row indices defining part of the CSC format. /// The row indices defining part of the CSC format.
#[inline] #[inline]
pub fn row_indices(&self) -> &[usize] { pub fn row_indices(&self) -> &[usize] {
self.sparsity_pattern.minor_indices() self.pattern().minor_indices()
} }
/// The non-zero values defining part of the CSC format. /// The non-zero values defining part of the CSC format.
#[inline] #[inline]
pub fn values(&self) -> &[T] { pub fn values(&self) -> &[T] {
&self.values self.cs.values()
} }
/// Mutable access to the non-zero values. /// Mutable access to the non-zero values.
#[inline] #[inline]
pub fn values_mut(&mut self) -> &mut [T] { pub fn values_mut(&mut self) -> &mut [T] {
&mut self.values self.cs.values_mut()
} }
/// Try to construct a CSC matrix from raw CSC data. /// Try to construct a CSC matrix from raw CSC data.
@ -109,8 +106,7 @@ impl<T> CscMatrix<T> {
-> Result<Self, SparseFormatError> { -> Result<Self, SparseFormatError> {
if pattern.nnz() == values.len() { if pattern.nnz() == values.len() {
Ok(Self { Ok(Self {
sparsity_pattern: pattern, cs: CsMatrix::from_pattern_and_values(pattern, values)
values,
}) })
} else { } else {
Err(SparseFormatError::from_kind_and_msg( Err(SparseFormatError::from_kind_and_msg(
@ -140,8 +136,8 @@ impl<T> CscMatrix<T> {
/// ``` /// ```
pub fn triplet_iter(&self) -> CscTripletIter<T> { pub fn triplet_iter(&self) -> CscTripletIter<T> {
CscTripletIter { CscTripletIter {
pattern_iter: self.sparsity_pattern.entries(), pattern_iter: self.pattern().entries(),
values_iter: self.values.iter() values_iter: self.values().iter()
} }
} }
@ -169,9 +165,10 @@ impl<T> CscMatrix<T> {
/// assert_eq!(triplets, vec![(0, 0, 1), (2, 0, 0), (1, 1, 2), (0, 2, 4)]); /// assert_eq!(triplets, vec![(0, 0, 1), (2, 0, 0), (1, 1, 2), (0, 2, 4)]);
/// ``` /// ```
pub fn triplet_iter_mut(&mut self) -> CscTripletIterMut<T> { pub fn triplet_iter_mut(&mut self) -> CscTripletIterMut<T> {
let (pattern, values) = self.cs.pattern_and_values_mut();
CscTripletIterMut { CscTripletIterMut {
pattern_iter: self.sparsity_pattern.entries(), pattern_iter: pattern.entries(),
values_mut_iter: self.values.iter_mut() values_mut_iter: values.iter_mut()
} }
} }
@ -200,54 +197,34 @@ impl<T> CscMatrix<T> {
/// Return the column at the given column index, or `None` if out of bounds. /// Return the column at the given column index, or `None` if out of bounds.
#[inline] #[inline]
pub fn get_col(&self, index: usize) -> Option<CscCol<T>> { pub fn get_col(&self, index: usize) -> Option<CscCol<T>> {
let range = self.get_index_range(index)?; self.cs
Some(CscCol { .get_lane(index)
row_indices: &self.sparsity_pattern.minor_indices()[range.clone()], .map(|lane| CscCol { lane })
values: &self.values[range],
nrows: self.nrows()
})
} }
/// Mutable column access for the given column index, or `None` if out of bounds. /// Mutable column access for the given column index, or `None` if out of bounds.
#[inline] #[inline]
pub fn get_col_mut(&mut self, index: usize) -> Option<CscColMut<T>> { pub fn get_col_mut(&mut self, index: usize) -> Option<CscColMut<T>> {
let range = self.get_index_range(index)?; self.cs
Some(CscColMut { .get_lane_mut(index)
nrows: self.nrows(), .map(|lane| CscColMut { lane })
row_indices: &self.sparsity_pattern.minor_indices()[range.clone()],
values: &mut self.values[range]
})
}
/// Internal method for simplifying access to a column's data.
fn get_index_range(&self, col_index: usize) -> Option<Range<usize>> {
let col_begin = *self.sparsity_pattern.major_offsets().get(col_index)?;
let col_end = *self.sparsity_pattern.major_offsets().get(col_index + 1)?;
Some(col_begin .. col_end)
} }
/// An iterator over columns in the matrix. /// An iterator over columns in the matrix.
pub fn col_iter(&self) -> CscColIter<T> { pub fn col_iter(&self) -> CscColIter<T> {
CscColIter { CscColIter {
current_col_idx: 0, lane_iter: CsLaneIter::new(self.pattern().as_ref(), self.values())
matrix: self
} }
} }
/// A mutable iterator over columns in the matrix. /// A mutable iterator over columns in the matrix.
pub fn col_iter_mut(&mut self) -> CscColIterMut<T> { pub fn col_iter_mut(&mut self) -> CscColIterMut<T> {
let (pattern, values) = self.cs.pattern_and_values_mut();
CscColIterMut { CscColIterMut {
current_col_idx: 0, lane_iter: CsLaneIterMut::new(pattern, values)
pattern: &self.sparsity_pattern,
remaining_values: self.values.as_mut_ptr()
} }
} }
/// Returns the underlying vector containing the values for the explicitly stored entries.
pub fn take_values(self) -> Vec<T> {
self.values
}
/// Disassembles the CSC matrix into its underlying offset, index and value arrays. /// Disassembles the CSC matrix into its underlying offset, index and value arrays.
/// ///
/// If the matrix contains the sole reference to the sparsity pattern, /// If the matrix contains the sole reference to the sparsity pattern,
@ -274,19 +251,7 @@ impl<T> CscMatrix<T> {
/// assert_eq!(values2, values); /// assert_eq!(values2, values);
/// ``` /// ```
pub fn disassemble(self) -> (Vec<usize>, Vec<usize>, Vec<T>) { pub fn disassemble(self) -> (Vec<usize>, Vec<usize>, Vec<T>) {
// Take an Arc to the pattern, which might be the sole reference to the data after self.cs.disassemble()
// taking the values. This is important, because it might let us avoid cloning the data
// further below.
let pattern = self.sparsity_pattern;
let values = self.values;
// Try to take the pattern out of the `Arc` if possible,
// otherwise clone the pattern.
let owned_pattern = Arc::try_unwrap(pattern)
.unwrap_or_else(|arc| SparsityPattern::clone(&*arc));
let (offsets, indices) = owned_pattern.disassemble();
(offsets, indices, values)
} }
/// Returns the underlying sparsity pattern. /// Returns the underlying sparsity pattern.
@ -295,15 +260,14 @@ impl<T> CscMatrix<T> {
/// the same sparsity pattern for multiple matrices without storing the same pattern multiple /// the same sparsity pattern for multiple matrices without storing the same pattern multiple
/// times in memory. /// times in memory.
pub fn pattern(&self) -> &Arc<SparsityPattern> { pub fn pattern(&self) -> &Arc<SparsityPattern> {
&self.sparsity_pattern self.cs.pattern()
} }
/// Reinterprets the CSC matrix as its transpose represented by a CSR matrix. /// Reinterprets the CSC matrix as its transpose represented by a CSR matrix.
/// ///
/// This operation does not touch the CSC data, and is effectively a no-op. /// This operation does not touch the CSC data, and is effectively a no-op.
pub fn transpose_as_csr(self) -> CsrMatrix<T> { pub fn transpose_as_csr(self) -> CsrMatrix<T> {
let pattern = self.sparsity_pattern; let (pattern, values) = self.cs.take_pattern_and_values();
let values = self.values;
CsrMatrix::try_from_pattern_and_values(pattern, values).unwrap() CsrMatrix::try_from_pattern_and_values(pattern, values).unwrap()
} }
} }
@ -422,9 +386,7 @@ impl<'a, T> Iterator for CscTripletIterMut<'a, T> {
/// An immutable representation of a column in a CSC matrix. /// An immutable representation of a column in a CSC matrix.
#[derive(Debug, Clone, PartialEq, Eq)] #[derive(Debug, Clone, PartialEq, Eq)]
pub struct CscCol<'a, T> { pub struct CscCol<'a, T> {
nrows: usize, lane: CsLane<'a, T>
row_indices: &'a [usize],
values: &'a [T],
} }
/// A mutable representation of a column in a CSC matrix. /// A mutable representation of a column in a CSC matrix.
@ -433,9 +395,7 @@ pub struct CscCol<'a, T> {
/// to the column cannot be modified. /// to the column cannot be modified.
#[derive(Debug, PartialEq, Eq)] #[derive(Debug, PartialEq, Eq)]
pub struct CscColMut<'a, T> { pub struct CscColMut<'a, T> {
nrows: usize, lane: CsLaneMut<'a, T>
row_indices: &'a [usize],
values: &'a mut [T]
} }
/// Implement the methods common to both CscCol and CscColMut /// Implement the methods common to both CscCol and CscColMut
@ -445,25 +405,25 @@ macro_rules! impl_csc_col_common_methods {
/// The number of global rows in the column. /// The number of global rows in the column.
#[inline] #[inline]
pub fn nrows(&self) -> usize { pub fn nrows(&self) -> usize {
self.nrows self.lane.minor_dim
} }
/// The number of non-zeros in this column. /// The number of non-zeros in this column.
#[inline] #[inline]
pub fn nnz(&self) -> usize { pub fn nnz(&self) -> usize {
self.row_indices.len() self.lane.minor_indices.len()
} }
/// The row indices corresponding to explicitly stored entries in this column. /// The row indices corresponding to explicitly stored entries in this column.
#[inline] #[inline]
pub fn row_indices(&self) -> &[usize] { pub fn row_indices(&self) -> &[usize] {
self.row_indices self.lane.minor_indices
} }
/// The values corresponding to explicitly stored entries in this column. /// The values corresponding to explicitly stored entries in this column.
#[inline] #[inline]
pub fn values(&self) -> &[T] { pub fn values(&self) -> &[T] {
self.values self.lane.values
} }
} }
@ -480,8 +440,8 @@ macro_rules! impl_csc_col_common_methods {
pub fn get(&self, global_row_index: usize) -> Option<T> { pub fn get(&self, global_row_index: usize) -> Option<T> {
let local_index = self.row_indices().binary_search(&global_row_index); let local_index = self.row_indices().binary_search(&global_row_index);
if let Ok(local_index) = local_index { if let Ok(local_index) = local_index {
Some(self.values[local_index].clone()) Some(self.values()[local_index].clone())
} else if global_row_index < self.nrows { } else if global_row_index < self.lane.minor_dim {
Some(T::zero()) Some(T::zero())
} else { } else {
None None
@ -497,7 +457,7 @@ impl_csc_col_common_methods!(CscColMut<'a, T>);
impl<'a, T> CscColMut<'a, T> { impl<'a, T> CscColMut<'a, T> {
/// Mutable access to the values corresponding to explicitly stored entries in this column. /// Mutable access to the values corresponding to explicitly stored entries in this column.
pub fn values_mut(&mut self) -> &mut [T] { pub fn values_mut(&mut self) -> &mut [T] {
self.values self.lane.values
} }
/// Provides simultaneous access to row indices and mutable values corresponding to the /// Provides simultaneous access to row indices and mutable values corresponding to the
@ -506,32 +466,28 @@ impl<'a, T> CscColMut<'a, T> {
/// This method primarily facilitates low-level access for methods that process data stored /// This method primarily facilitates low-level access for methods that process data stored
/// in CSC format directly. /// in CSC format directly.
pub fn rows_and_values_mut(&mut self) -> (&[usize], &mut [T]) { pub fn rows_and_values_mut(&mut self) -> (&[usize], &mut [T]) {
(self.row_indices, self.values) (self.lane.minor_indices, self.lane.values)
} }
} }
/// Column iterator for [CscMatrix](struct.CscMatrix.html). /// Column iterator for [CscMatrix](struct.CscMatrix.html).
pub struct CscColIter<'a, T> { pub struct CscColIter<'a, T> {
// The index of the row that will be returned on the next lane_iter: CsLaneIter<'a, T>
current_col_idx: usize,
matrix: &'a CscMatrix<T>
} }
impl<'a, T> Iterator for CscColIter<'a, T> { impl<'a, T> Iterator for CscColIter<'a, T> {
type Item = CscCol<'a, T>; type Item = CscCol<'a, T>;
fn next(&mut self) -> Option<Self::Item> { fn next(&mut self) -> Option<Self::Item> {
let col = self.matrix.get_col(self.current_col_idx); self.lane_iter
self.current_col_idx += 1; .next()
col .map(|lane| CscCol { lane })
} }
} }
/// Mutable column iterator for [CscMatrix](struct.CscMatrix.html). /// Mutable column iterator for [CscMatrix](struct.CscMatrix.html).
pub struct CscColIterMut<'a, T> { pub struct CscColIterMut<'a, T> {
current_col_idx: usize, lane_iter: CsLaneIterMut<'a, T>
pattern: &'a SparsityPattern,
remaining_values: *mut T,
} }
impl<'a, T> Iterator for CscColIterMut<'a, T> impl<'a, T> Iterator for CscColIterMut<'a, T>
@ -541,27 +497,8 @@ where
type Item = CscColMut<'a, T>; type Item = CscColMut<'a, T>;
fn next(&mut self) -> Option<Self::Item> { fn next(&mut self) -> Option<Self::Item> {
let lane = self.pattern.get_lane(self.current_col_idx); self.lane_iter
let nrows = self.pattern.minor_dim(); .next()
.map(|lane| CscColMut { lane })
if let Some(row_indices) = lane {
let count = row_indices.len();
// Note: I can't think of any way to construct this iterator without unsafe.
let values_in_row;
unsafe {
values_in_row = &mut *slice_from_raw_parts_mut(self.remaining_values, count);
self.remaining_values = self.remaining_values.add(count);
}
self.current_col_idx += 1;
Some(CscColMut {
nrows,
row_indices,
values: values_in_row
})
} else {
None
}
} }
} }

View File

@ -2,14 +2,13 @@
use crate::{SparseFormatError, SparseFormatErrorKind, SparseEntry, SparseEntryMut}; use crate::{SparseFormatError, SparseFormatErrorKind, SparseEntry, SparseEntryMut};
use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter}; use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter};
use crate::csc::CscMatrix; use crate::csc::CscMatrix;
use crate::cs::{CsMatrix, get_entry_from_slices, get_mut_entry_from_slices, CsLaneIterMut, CsLaneIter, CsLane, CsLaneMut};
use nalgebra::Scalar; use nalgebra::Scalar;
use num_traits::Zero; use num_traits::Zero;
use std::sync::Arc; use std::sync::Arc;
use std::slice::{IterMut, Iter}; use std::slice::{IterMut, Iter};
use std::ops::Range;
use std::ptr::slice_from_raw_parts_mut;
/// A CSR representation of a sparse matrix. /// A CSR representation of a sparse matrix.
/// ///
@ -21,29 +20,27 @@ use std::ptr::slice_from_raw_parts_mut;
#[derive(Debug, Clone, PartialEq, Eq)] #[derive(Debug, Clone, PartialEq, Eq)]
pub struct CsrMatrix<T> { pub struct CsrMatrix<T> {
// Rows are major, cols are minor in the sparsity pattern // Rows are major, cols are minor in the sparsity pattern
sparsity_pattern: Arc<SparsityPattern>, cs: CsMatrix<T>,
values: Vec<T>,
} }
impl<T> CsrMatrix<T> { impl<T> CsrMatrix<T> {
/// Create a zero CSR matrix with no explicitly stored entries. /// Create a zero CSR matrix with no explicitly stored entries.
pub fn new(nrows: usize, ncols: usize) -> Self { pub fn new(nrows: usize, ncols: usize) -> Self {
Self { Self {
sparsity_pattern: Arc::new(SparsityPattern::new(nrows, ncols)), cs: CsMatrix::new(nrows, ncols)
values: vec![],
} }
} }
/// The number of rows in the matrix. /// The number of rows in the matrix.
#[inline] #[inline]
pub fn nrows(&self) -> usize { pub fn nrows(&self) -> usize {
self.sparsity_pattern.major_dim() self.cs.pattern().major_dim()
} }
/// The number of columns in the matrix. /// The number of columns in the matrix.
#[inline] #[inline]
pub fn ncols(&self) -> usize { pub fn ncols(&self) -> usize {
self.sparsity_pattern.minor_dim() self.cs.pattern().minor_dim()
} }
/// The number of non-zeros in the matrix. /// The number of non-zeros in the matrix.
@ -53,31 +50,33 @@ impl<T> CsrMatrix<T> {
/// be zero. Corresponds to the number of entries in the sparsity pattern. /// be zero. Corresponds to the number of entries in the sparsity pattern.
#[inline] #[inline]
pub fn nnz(&self) -> usize { pub fn nnz(&self) -> usize {
self.sparsity_pattern.nnz() self.cs.pattern().nnz()
} }
/// The row offsets defining part of the CSR format. /// The row offsets defining part of the CSR format.
#[inline] #[inline]
pub fn row_offsets(&self) -> &[usize] { pub fn row_offsets(&self) -> &[usize] {
self.sparsity_pattern.major_offsets() let (offsets, _, _) = self.cs.cs_data();
offsets
} }
/// The column indices defining part of the CSR format. /// The column indices defining part of the CSR format.
#[inline] #[inline]
pub fn col_indices(&self) -> &[usize] { pub fn col_indices(&self) -> &[usize] {
self.sparsity_pattern.minor_indices() let (_, indices, _) = self.cs.cs_data();
indices
} }
/// The non-zero values defining part of the CSR format. /// The non-zero values defining part of the CSR format.
#[inline] #[inline]
pub fn values(&self) -> &[T] { pub fn values(&self) -> &[T] {
&self.values self.cs.values()
} }
/// Mutable access to the non-zero values. /// Mutable access to the non-zero values.
#[inline] #[inline]
pub fn values_mut(&mut self) -> &mut [T] { pub fn values_mut(&mut self) -> &mut [T] {
&mut self.values self.cs.values_mut()
} }
/// Try to construct a CSR matrix from raw CSR data. /// Try to construct a CSR matrix from raw CSR data.
@ -109,8 +108,7 @@ impl<T> CsrMatrix<T> {
-> Result<Self, SparseFormatError> { -> Result<Self, SparseFormatError> {
if pattern.nnz() == values.len() { if pattern.nnz() == values.len() {
Ok(Self { Ok(Self {
sparsity_pattern: pattern, cs: CsMatrix::from_pattern_and_values(pattern, values)
values,
}) })
} else { } else {
Err(SparseFormatError::from_kind_and_msg( Err(SparseFormatError::from_kind_and_msg(
@ -119,7 +117,6 @@ impl<T> CsrMatrix<T> {
} }
} }
/// An iterator over non-zero triplets (i, j, v). /// An iterator over non-zero triplets (i, j, v).
/// ///
/// The iteration happens in row-major fashion, meaning that i increases monotonically, /// The iteration happens in row-major fashion, meaning that i increases monotonically,
@ -140,8 +137,8 @@ impl<T> CsrMatrix<T> {
/// ``` /// ```
pub fn triplet_iter(&self) -> CsrTripletIter<T> { pub fn triplet_iter(&self) -> CsrTripletIter<T> {
CsrTripletIter { CsrTripletIter {
pattern_iter: self.sparsity_pattern.entries(), pattern_iter: self.pattern().entries(),
values_iter: self.values.iter() values_iter: self.values().iter()
} }
} }
@ -169,9 +166,10 @@ impl<T> CsrMatrix<T> {
/// assert_eq!(triplets, vec![(0, 0, 1), (0, 2, 2), (1, 1, 3), (2, 0, 0)]); /// assert_eq!(triplets, vec![(0, 0, 1), (0, 2, 2), (1, 1, 3), (2, 0, 0)]);
/// ``` /// ```
pub fn triplet_iter_mut(&mut self) -> CsrTripletIterMut<T> { pub fn triplet_iter_mut(&mut self) -> CsrTripletIterMut<T> {
let (pattern, values) = self.cs.pattern_and_values_mut();
CsrTripletIterMut { CsrTripletIterMut {
pattern_iter: self.sparsity_pattern.entries(), pattern_iter: pattern.entries(),
values_mut_iter: self.values.iter_mut() values_mut_iter: values.iter_mut()
} }
} }
@ -200,54 +198,34 @@ impl<T> CsrMatrix<T> {
/// Return the row at the given row index, or `None` if out of bounds. /// Return the row at the given row index, or `None` if out of bounds.
#[inline] #[inline]
pub fn get_row(&self, index: usize) -> Option<CsrRow<T>> { pub fn get_row(&self, index: usize) -> Option<CsrRow<T>> {
let range = self.get_index_range(index)?; self.cs
Some(CsrRow { .get_lane(index)
col_indices: &self.sparsity_pattern.minor_indices()[range.clone()], .map(|lane| CsrRow { lane })
values: &self.values[range],
ncols: self.ncols()
})
} }
/// Mutable row access for the given row index, or `None` if out of bounds. /// Mutable row access for the given row index, or `None` if out of bounds.
#[inline] #[inline]
pub fn get_row_mut(&mut self, index: usize) -> Option<CsrRowMut<T>> { pub fn get_row_mut(&mut self, index: usize) -> Option<CsrRowMut<T>> {
let range = self.get_index_range(index)?; self.cs
Some(CsrRowMut { .get_lane_mut(index)
ncols: self.ncols(), .map(|lane| CsrRowMut { lane })
col_indices: &self.sparsity_pattern.minor_indices()[range.clone()],
values: &mut self.values[range]
})
}
/// Internal method for simplifying access to a row's data.
fn get_index_range(&self, row_index: usize) -> Option<Range<usize>> {
let row_begin = *self.sparsity_pattern.major_offsets().get(row_index)?;
let row_end = *self.sparsity_pattern.major_offsets().get(row_index + 1)?;
Some(row_begin .. row_end)
} }
/// An iterator over rows in the matrix. /// An iterator over rows in the matrix.
pub fn row_iter(&self) -> CsrRowIter<T> { pub fn row_iter(&self) -> CsrRowIter<T> {
CsrRowIter { CsrRowIter {
current_row_idx: 0, lane_iter: CsLaneIter::new(self.pattern().as_ref(), self.values())
matrix: self
} }
} }
/// A mutable iterator over rows in the matrix. /// A mutable iterator over rows in the matrix.
pub fn row_iter_mut(&mut self) -> CsrRowIterMut<T> { pub fn row_iter_mut(&mut self) -> CsrRowIterMut<T> {
let (pattern, values) = self.cs.pattern_and_values_mut();
CsrRowIterMut { CsrRowIterMut {
current_row_idx: 0, lane_iter: CsLaneIterMut::new(pattern, values),
pattern: &self.sparsity_pattern,
remaining_values: self.values.as_mut_ptr()
} }
} }
/// Returns the underlying vector containing the values for the explicitly stored entries.
pub fn take_values(self) -> Vec<T> {
self.values
}
/// Disassembles the CSR matrix into its underlying offset, index and value arrays. /// Disassembles the CSR matrix into its underlying offset, index and value arrays.
/// ///
/// If the matrix contains the sole reference to the sparsity pattern, /// If the matrix contains the sole reference to the sparsity pattern,
@ -274,19 +252,7 @@ impl<T> CsrMatrix<T> {
/// assert_eq!(values2, values); /// assert_eq!(values2, values);
/// ``` /// ```
pub fn disassemble(self) -> (Vec<usize>, Vec<usize>, Vec<T>) { pub fn disassemble(self) -> (Vec<usize>, Vec<usize>, Vec<T>) {
// Take an Arc to the pattern, which might be the sole reference to the data after self.cs.disassemble()
// taking the values. This is important, because it might let us avoid cloning the data
// further below.
let pattern = self.sparsity_pattern;
let values = self.values;
// Try to take the pattern out of the `Arc` if possible,
// otherwise clone the pattern.
let owned_pattern = Arc::try_unwrap(pattern)
.unwrap_or_else(|arc| SparsityPattern::clone(&*arc));
let (offsets, indices) = owned_pattern.disassemble();
(offsets, indices, values)
} }
/// Returns the underlying sparsity pattern. /// Returns the underlying sparsity pattern.
@ -295,15 +261,14 @@ impl<T> CsrMatrix<T> {
/// the same sparsity pattern for multiple matrices without storing the same pattern multiple /// the same sparsity pattern for multiple matrices without storing the same pattern multiple
/// times in memory. /// times in memory.
pub fn pattern(&self) -> &Arc<SparsityPattern> { pub fn pattern(&self) -> &Arc<SparsityPattern> {
&self.sparsity_pattern self.cs.pattern()
} }
/// Reinterprets the CSR matrix as its transpose represented by a CSC matrix. /// Reinterprets the CSR matrix as its transpose represented by a CSC matrix.
/// ///
/// This operation does not touch the CSR data, and is effectively a no-op. /// This operation does not touch the CSR data, and is effectively a no-op.
pub fn transpose_as_csc(self) -> CscMatrix<T> { pub fn transpose_as_csc(self) -> CscMatrix<T> {
let pattern = self.sparsity_pattern; let (pattern, values) = self.cs.take_pattern_and_values();
let values = self.values;
CscMatrix::try_from_pattern_and_values(pattern, values).unwrap() CscMatrix::try_from_pattern_and_values(pattern, values).unwrap()
} }
@ -312,10 +277,7 @@ impl<T> CsrMatrix<T> {
/// Each call to this function incurs the cost of a binary search among the explicitly /// Each call to this function incurs the cost of a binary search among the explicitly
/// stored column entries for the given row. /// stored column entries for the given row.
pub fn get_entry(&self, row_index: usize, col_index: usize) -> Option<SparseEntry<T>> { pub fn get_entry(&self, row_index: usize, col_index: usize) -> Option<SparseEntry<T>> {
let row_range = self.get_index_range(row_index)?; self.cs.get_entry(row_index, col_index)
let col_indices = &self.col_indices()[row_range.clone()];
let values = &self.values()[row_range];
get_entry_from_slices(self.ncols(), col_indices, values, col_index)
} }
/// Returns a mutable entry for the given row/col indices, or `None` if the indices are out /// Returns a mutable entry for the given row/col indices, or `None` if the indices are out
@ -325,12 +287,7 @@ impl<T> CsrMatrix<T> {
/// stored column entries for the given row. /// stored column entries for the given row.
pub fn get_entry_mut(&mut self, row_index: usize, col_index: usize) pub fn get_entry_mut(&mut self, row_index: usize, col_index: usize)
-> Option<SparseEntryMut<T>> { -> Option<SparseEntryMut<T>> {
let row_range = self.get_index_range(row_index)?; self.cs.get_entry_mut(row_index, col_index)
let ncols = self.ncols();
let (_, col_indices, values) = self.csr_data_mut();
let col_indices = &col_indices[row_range.clone()];
let values = &mut values[row_range];
get_mut_entry_from_slices(ncols, col_indices, values, col_index)
} }
/// Returns an entry for the given row/col indices. /// Returns an entry for the given row/col indices.
@ -361,14 +318,13 @@ impl<T> CsrMatrix<T> {
/// Returns a triplet of slices `(row_offsets, col_indices, values)` that make up the CSR data. /// Returns a triplet of slices `(row_offsets, col_indices, values)` that make up the CSR data.
pub fn csr_data(&self) -> (&[usize], &[usize], &[T]) { pub fn csr_data(&self) -> (&[usize], &[usize], &[T]) {
(self.row_offsets(), self.col_indices(), self.values()) self.cs.cs_data()
} }
/// Returns a triplet of slices `(row_offsets, col_indices, values)` that make up the CSR data, /// Returns a triplet of slices `(row_offsets, col_indices, values)` that make up the CSR data,
/// where the `values` array is mutable. /// where the `values` array is mutable.
pub fn csr_data_mut(&mut self) -> (&[usize], &[usize], &mut [T]) { pub fn csr_data_mut(&mut self) -> (&[usize], &[usize], &mut [T]) {
let pattern = self.sparsity_pattern.as_ref(); self.cs.cs_data_mut()
(pattern.major_offsets(), pattern.minor_indices(), &mut self.values)
} }
} }
@ -460,9 +416,7 @@ impl<'a, T> Iterator for CsrTripletIterMut<'a, T> {
/// An immutable representation of a row in a CSR matrix. /// An immutable representation of a row in a CSR matrix.
#[derive(Debug, Clone, PartialEq, Eq)] #[derive(Debug, Clone, PartialEq, Eq)]
pub struct CsrRow<'a, T> { pub struct CsrRow<'a, T> {
ncols: usize, lane: CsLane<'a, T>
col_indices: &'a [usize],
values: &'a [T],
} }
/// A mutable representation of a row in a CSR matrix. /// A mutable representation of a row in a CSR matrix.
@ -471,9 +425,7 @@ pub struct CsrRow<'a, T> {
/// to the row cannot be modified. /// to the row cannot be modified.
#[derive(Debug, PartialEq, Eq)] #[derive(Debug, PartialEq, Eq)]
pub struct CsrRowMut<'a, T> { pub struct CsrRowMut<'a, T> {
ncols: usize, lane: CsLaneMut<'a, T>
col_indices: &'a [usize],
values: &'a mut [T]
} }
/// Implement the methods common to both CsrRow and CsrRowMut /// Implement the methods common to both CsrRow and CsrRowMut
@ -483,25 +435,25 @@ macro_rules! impl_csr_row_common_methods {
/// The number of global columns in the row. /// The number of global columns in the row.
#[inline] #[inline]
pub fn ncols(&self) -> usize { pub fn ncols(&self) -> usize {
self.ncols self.lane.minor_dim
} }
/// The number of non-zeros in this row. /// The number of non-zeros in this row.
#[inline] #[inline]
pub fn nnz(&self) -> usize { pub fn nnz(&self) -> usize {
self.col_indices.len() self.lane.minor_indices.len()
} }
/// The column indices corresponding to explicitly stored entries in this row. /// The column indices corresponding to explicitly stored entries in this row.
#[inline] #[inline]
pub fn col_indices(&self) -> &[usize] { pub fn col_indices(&self) -> &[usize] {
self.col_indices self.lane.minor_indices
} }
/// The values corresponding to explicitly stored entries in this row. /// The values corresponding to explicitly stored entries in this row.
#[inline] #[inline]
pub fn values(&self) -> &[T] { pub fn values(&self) -> &[T] {
self.values self.lane.values
} }
/// Returns an entry for the given global column index. /// Returns an entry for the given global column index.
@ -509,47 +461,23 @@ macro_rules! impl_csr_row_common_methods {
/// Each call to this function incurs the cost of a binary search among the explicitly /// Each call to this function incurs the cost of a binary search among the explicitly
/// stored column entries. /// stored column entries.
pub fn get_entry(&self, global_col_index: usize) -> Option<SparseEntry<T>> { pub fn get_entry(&self, global_col_index: usize) -> Option<SparseEntry<T>> {
get_entry_from_slices(self.ncols, self.col_indices, self.values, global_col_index) get_entry_from_slices(
self.lane.minor_dim,
self.lane.minor_indices,
self.lane.values,
global_col_index)
} }
} }
} }
} }
fn get_entry_from_slices<'a, T>(ncols: usize,
col_indices: &'a [usize],
values: &'a [T],
global_col_index: usize) -> Option<SparseEntry<'a, T>> {
let local_index = col_indices.binary_search(&global_col_index);
if let Ok(local_index) = local_index {
Some(SparseEntry::NonZero(&values[local_index]))
} else if global_col_index < ncols {
Some(SparseEntry::Zero)
} else {
None
}
}
fn get_mut_entry_from_slices<'a, T>(ncols: usize,
col_indices: &'a [usize],
values: &'a mut [T],
global_col_index: usize) -> Option<SparseEntryMut<'a, T>> {
let local_index = col_indices.binary_search(&global_col_index);
if let Ok(local_index) = local_index {
Some(SparseEntryMut::NonZero(&mut values[local_index]))
} else if global_col_index < ncols {
Some(SparseEntryMut::Zero)
} else {
None
}
}
impl_csr_row_common_methods!(CsrRow<'a, T>); impl_csr_row_common_methods!(CsrRow<'a, T>);
impl_csr_row_common_methods!(CsrRowMut<'a, T>); impl_csr_row_common_methods!(CsrRowMut<'a, T>);
impl<'a, T> CsrRowMut<'a, T> { impl<'a, T> CsrRowMut<'a, T> {
/// Mutable access to the values corresponding to explicitly stored entries in this row. /// Mutable access to the values corresponding to explicitly stored entries in this row.
pub fn values_mut(&mut self) -> &mut [T] { pub fn values_mut(&mut self) -> &mut [T] {
self.values self.lane.values
} }
/// Provides simultaneous access to column indices and mutable values corresponding to the /// Provides simultaneous access to column indices and mutable values corresponding to the
@ -558,37 +486,36 @@ impl<'a, T> CsrRowMut<'a, T> {
/// This method primarily facilitates low-level access for methods that process data stored /// This method primarily facilitates low-level access for methods that process data stored
/// in CSR format directly. /// in CSR format directly.
pub fn cols_and_values_mut(&mut self) -> (&[usize], &mut [T]) { pub fn cols_and_values_mut(&mut self) -> (&[usize], &mut [T]) {
(self.col_indices, self.values) (self.lane.minor_indices, self.lane.values)
} }
/// Returns a mutable entry for the given global column index. /// Returns a mutable entry for the given global column index.
pub fn get_entry_mut(&mut self, global_col_index: usize) -> Option<SparseEntryMut<T>> { pub fn get_entry_mut(&mut self, global_col_index: usize) -> Option<SparseEntryMut<T>> {
get_mut_entry_from_slices(self.ncols, self.col_indices, self.values, global_col_index) get_mut_entry_from_slices(self.lane.minor_dim,
self.lane.minor_indices,
self.lane.values,
global_col_index)
} }
} }
/// Row iterator for [CsrMatrix](struct.CsrMatrix.html). /// Row iterator for [CsrMatrix](struct.CsrMatrix.html).
pub struct CsrRowIter<'a, T> { pub struct CsrRowIter<'a, T> {
// The index of the row that will be returned on the next lane_iter: CsLaneIter<'a, T>
current_row_idx: usize,
matrix: &'a CsrMatrix<T>
} }
impl<'a, T> Iterator for CsrRowIter<'a, T> { impl<'a, T> Iterator for CsrRowIter<'a, T> {
type Item = CsrRow<'a, T>; type Item = CsrRow<'a, T>;
fn next(&mut self) -> Option<Self::Item> { fn next(&mut self) -> Option<Self::Item> {
let row = self.matrix.get_row(self.current_row_idx); self.lane_iter
self.current_row_idx += 1; .next()
row .map(|lane| CsrRow { lane })
} }
} }
/// Mutable row iterator for [CsrMatrix](struct.CsrMatrix.html). /// Mutable row iterator for [CsrMatrix](struct.CsrMatrix.html).
pub struct CsrRowIterMut<'a, T> { pub struct CsrRowIterMut<'a, T> {
current_row_idx: usize, lane_iter: CsLaneIterMut<'a, T>
pattern: &'a SparsityPattern,
remaining_values: *mut T,
} }
impl<'a, T> Iterator for CsrRowIterMut<'a, T> impl<'a, T> Iterator for CsrRowIterMut<'a, T>
@ -598,27 +525,8 @@ where
type Item = CsrRowMut<'a, T>; type Item = CsrRowMut<'a, T>;
fn next(&mut self) -> Option<Self::Item> { fn next(&mut self) -> Option<Self::Item> {
let lane = self.pattern.get_lane(self.current_row_idx); self.lane_iter
let ncols = self.pattern.minor_dim(); .next()
.map(|lane| CsrRowMut { lane })
if let Some(col_indices) = lane {
let count = col_indices.len();
// Note: I can't think of any way to construct this iterator without unsafe.
let values_in_row;
unsafe {
values_in_row = &mut *slice_from_raw_parts_mut(self.remaining_values, count);
self.remaining_values = self.remaining_values.add(count);
}
self.current_row_idx += 1;
Some(CsrRowMut {
ncols,
col_indices,
values: values_in_row
})
} else {
None
}
} }
} }

View File

@ -90,6 +90,8 @@ pub mod pattern;
pub mod ops; pub mod ops;
pub mod convert; pub mod convert;
mod cs;
#[cfg(feature = "proptest-support")] #[cfg(feature = "proptest-support")]
pub mod proptest; pub mod proptest;