Refactor most of Csr/CscMatrix logic into helper type CsMatrix
Still need to update CSC API so that it mirrors CsrMatrix in terms of get_entry and so on.
This commit is contained in:
parent
8983027b39
commit
b59c4a3216
|
@ -0,0 +1,286 @@
|
||||||
|
use crate::pattern::SparsityPattern;
|
||||||
|
use crate::{SparseEntry, SparseEntryMut};
|
||||||
|
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::ops::Range;
|
||||||
|
use std::ptr::slice_from_raw_parts_mut;
|
||||||
|
|
||||||
|
/// An abstract compressed matrix.
|
||||||
|
///
|
||||||
|
/// For the time being, this is only used internally to share implementation between
|
||||||
|
/// CSR and CSC matrices.
|
||||||
|
///
|
||||||
|
/// A CSR matrix is obtained by associating rows with the major dimension, while a CSC matrix
|
||||||
|
/// is obtained by associating columns with the major dimension.
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub struct CsMatrix<T> {
|
||||||
|
sparsity_pattern: Arc<SparsityPattern>,
|
||||||
|
values: Vec<T>
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T> CsMatrix<T> {
|
||||||
|
/// Create a zero matrix with no explicitly stored entries.
|
||||||
|
#[inline]
|
||||||
|
pub fn new(major_dim: usize, minor_dim: usize) -> Self {
|
||||||
|
Self {
|
||||||
|
sparsity_pattern: Arc::new(SparsityPattern::new(major_dim, minor_dim)),
|
||||||
|
values: vec![],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn pattern(&self) -> &Arc<SparsityPattern> {
|
||||||
|
&self.sparsity_pattern
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn values(&self) -> &[T] {
|
||||||
|
&self.values
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn values_mut(&mut self) -> &mut [T] {
|
||||||
|
&mut self.values
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the raw data represented as a tuple `(major_offsets, minor_indices, values)`.
|
||||||
|
#[inline]
|
||||||
|
pub fn cs_data(&self) -> (&[usize], &[usize], &[T]) {
|
||||||
|
let pattern = self.pattern().as_ref();
|
||||||
|
(pattern.major_offsets(), pattern.minor_indices(), &self.values)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the raw data represented as a tuple `(major_offsets, minor_indices, values)`.
|
||||||
|
#[inline]
|
||||||
|
pub fn cs_data_mut(&mut self) -> (&[usize], &[usize], &mut [T]) {
|
||||||
|
let pattern = self.sparsity_pattern.as_ref();
|
||||||
|
(pattern.major_offsets(), pattern.minor_indices(), &mut self.values)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn pattern_and_values_mut(&mut self) -> (&Arc<SparsityPattern>, &mut [T]) {
|
||||||
|
(&self.sparsity_pattern, &mut self.values)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn from_pattern_and_values(pattern: Arc<SparsityPattern>, values: Vec<T>)
|
||||||
|
-> Self {
|
||||||
|
assert_eq!(pattern.nnz(), values.len(), "Internal error: consumers should verify shape compatibility.");
|
||||||
|
Self {
|
||||||
|
sparsity_pattern: pattern,
|
||||||
|
values,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Internal method for simplifying access to a lane's data
|
||||||
|
#[inline]
|
||||||
|
pub fn get_index_range(&self, row_index: usize) -> Option<Range<usize>> {
|
||||||
|
let row_begin = *self.sparsity_pattern.major_offsets().get(row_index)?;
|
||||||
|
let row_end = *self.sparsity_pattern.major_offsets().get(row_index + 1)?;
|
||||||
|
Some(row_begin .. row_end)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn take_pattern_and_values(self) -> (Arc<SparsityPattern>, Vec<T>) {
|
||||||
|
(self.sparsity_pattern, self.values)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn disassemble(self) -> (Vec<usize>, Vec<usize>, Vec<T>) {
|
||||||
|
// Take an Arc to the pattern, which might be the sole reference to the data after
|
||||||
|
// taking the values. This is important, because it might let us avoid cloning the data
|
||||||
|
// further below.
|
||||||
|
let pattern = self.sparsity_pattern;
|
||||||
|
let values = self.values;
|
||||||
|
|
||||||
|
// Try to take the pattern out of the `Arc` if possible,
|
||||||
|
// otherwise clone the pattern.
|
||||||
|
let owned_pattern = Arc::try_unwrap(pattern)
|
||||||
|
.unwrap_or_else(|arc| SparsityPattern::clone(&*arc));
|
||||||
|
let (offsets, indices) = owned_pattern.disassemble();
|
||||||
|
|
||||||
|
(offsets, indices, values)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns an entry for the given major/minor indices, or `None` if the indices are out
|
||||||
|
/// of bounds.
|
||||||
|
pub fn get_entry(&self, major_index: usize, minor_index: usize) -> Option<SparseEntry<T>> {
|
||||||
|
let row_range = self.get_index_range(major_index)?;
|
||||||
|
let (_, minor_indices, values) = self.cs_data();
|
||||||
|
let minor_indices = &minor_indices[row_range.clone()];
|
||||||
|
let values = &values[row_range];
|
||||||
|
get_entry_from_slices(self.pattern().minor_dim(), minor_indices, values, minor_index)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns a mutable entry for the given major/minor indices, or `None` if the indices are out
|
||||||
|
/// of bounds.
|
||||||
|
pub fn get_entry_mut(&mut self, major_index: usize, minor_index: usize)
|
||||||
|
-> Option<SparseEntryMut<T>> {
|
||||||
|
let row_range = self.get_index_range(major_index)?;
|
||||||
|
let minor_dim = self.pattern().minor_dim();
|
||||||
|
let (_, minor_indices, values) = self.cs_data_mut();
|
||||||
|
let minor_indices = &minor_indices[row_range.clone()];
|
||||||
|
let values = &mut values[row_range];
|
||||||
|
get_mut_entry_from_slices(minor_dim, minor_indices, values, minor_index)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_lane(&self, index: usize) -> Option<CsLane<T>> {
|
||||||
|
let range = self.get_index_range(index)?;
|
||||||
|
let (_, minor_indices, values) = self.cs_data();
|
||||||
|
Some(CsLane {
|
||||||
|
minor_indices: &minor_indices[range.clone()],
|
||||||
|
values: &values[range],
|
||||||
|
minor_dim: self.pattern().minor_dim()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn get_lane_mut(&mut self, index: usize) -> Option<CsLaneMut<T>> {
|
||||||
|
let range = self.get_index_range(index)?;
|
||||||
|
let minor_dim = self.pattern().minor_dim();
|
||||||
|
let (_, minor_indices, values) = self.cs_data_mut();
|
||||||
|
Some(CsLaneMut {
|
||||||
|
minor_dim,
|
||||||
|
minor_indices: &minor_indices[range.clone()],
|
||||||
|
values: &mut values[range]
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_entry_from_slices<'a, T>(
|
||||||
|
minor_dim: usize,
|
||||||
|
minor_indices: &'a [usize],
|
||||||
|
values: &'a [T],
|
||||||
|
global_minor_index: usize) -> Option<SparseEntry<'a, T>> {
|
||||||
|
let local_index = minor_indices.binary_search(&global_minor_index);
|
||||||
|
if let Ok(local_index) = local_index {
|
||||||
|
Some(SparseEntry::NonZero(&values[local_index]))
|
||||||
|
} else if global_minor_index < minor_dim {
|
||||||
|
Some(SparseEntry::Zero)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_mut_entry_from_slices<'a, T>(
|
||||||
|
minor_dim: usize,
|
||||||
|
minor_indices: &'a [usize],
|
||||||
|
values: &'a mut [T],
|
||||||
|
global_minor_indices: usize) -> Option<SparseEntryMut<'a, T>> {
|
||||||
|
let local_index = minor_indices.binary_search(&global_minor_indices);
|
||||||
|
if let Ok(local_index) = local_index {
|
||||||
|
Some(SparseEntryMut::NonZero(&mut values[local_index]))
|
||||||
|
} else if global_minor_indices < minor_dim {
|
||||||
|
Some(SparseEntryMut::Zero)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub struct CsLane<'a, T> {
|
||||||
|
pub minor_dim: usize,
|
||||||
|
pub minor_indices: &'a [usize],
|
||||||
|
pub values: &'a [T]
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, PartialEq, Eq)]
|
||||||
|
pub struct CsLaneMut<'a, T> {
|
||||||
|
pub minor_dim: usize,
|
||||||
|
pub minor_indices: &'a [usize],
|
||||||
|
pub values: &'a mut [T]
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct CsLaneIter<'a, T> {
|
||||||
|
// The index of the lane that will be returned on the next iteration
|
||||||
|
current_lane_idx: usize,
|
||||||
|
pattern: &'a SparsityPattern,
|
||||||
|
remaining_values: &'a [T],
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, T> CsLaneIter<'a, T> {
|
||||||
|
pub fn new(pattern: &'a SparsityPattern, values: &'a [T]) -> Self {
|
||||||
|
Self {
|
||||||
|
current_lane_idx: 0,
|
||||||
|
pattern,
|
||||||
|
remaining_values: values
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, T> Iterator for CsLaneIter<'a, T>
|
||||||
|
where
|
||||||
|
T: 'a
|
||||||
|
{
|
||||||
|
type Item = CsLane<'a, T>;
|
||||||
|
|
||||||
|
fn next(&mut self) -> Option<Self::Item> {
|
||||||
|
let lane = self.pattern.get_lane(self.current_lane_idx);
|
||||||
|
let minor_dim = self.pattern.minor_dim();
|
||||||
|
|
||||||
|
if let Some(minor_indices) = lane {
|
||||||
|
let count = minor_indices.len();
|
||||||
|
let values_in_lane = &self.remaining_values[..count];
|
||||||
|
self.remaining_values = &self.remaining_values[count ..];
|
||||||
|
self.current_lane_idx += 1;
|
||||||
|
|
||||||
|
Some(CsLane {
|
||||||
|
minor_dim,
|
||||||
|
minor_indices,
|
||||||
|
values: values_in_lane
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct CsLaneIterMut<'a, T> {
|
||||||
|
// The index of the lane that will be returned on the next iteration
|
||||||
|
current_lane_idx: usize,
|
||||||
|
pattern: &'a SparsityPattern,
|
||||||
|
remaining_values: *mut T,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, T> CsLaneIterMut<'a, T> {
|
||||||
|
pub fn new(pattern: &'a SparsityPattern, values: &'a mut [T]) -> Self {
|
||||||
|
Self {
|
||||||
|
current_lane_idx: 0,
|
||||||
|
pattern,
|
||||||
|
remaining_values: values.as_mut_ptr()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, T> Iterator for CsLaneIterMut<'a, T>
|
||||||
|
where
|
||||||
|
T: 'a
|
||||||
|
{
|
||||||
|
type Item = CsLaneMut<'a, T>;
|
||||||
|
|
||||||
|
fn next(&mut self) -> Option<Self::Item> {
|
||||||
|
let lane = self.pattern.get_lane(self.current_lane_idx);
|
||||||
|
let minor_dim = self.pattern.minor_dim();
|
||||||
|
|
||||||
|
if let Some(minor_indices) = lane {
|
||||||
|
let count = minor_indices.len();
|
||||||
|
|
||||||
|
// Note: I can't think of any way to construct this iterator without unsafe.
|
||||||
|
let values_in_lane;
|
||||||
|
unsafe {
|
||||||
|
values_in_lane = &mut *slice_from_raw_parts_mut(self.remaining_values, count);
|
||||||
|
self.remaining_values = self.remaining_values.add(count);
|
||||||
|
}
|
||||||
|
self.current_lane_idx += 1;
|
||||||
|
|
||||||
|
Some(CsLaneMut {
|
||||||
|
minor_dim,
|
||||||
|
minor_indices,
|
||||||
|
values: values_in_lane
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
|
@ -2,13 +2,12 @@
|
||||||
|
|
||||||
use crate::{SparseFormatError, SparseFormatErrorKind};
|
use crate::{SparseFormatError, SparseFormatErrorKind};
|
||||||
use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter};
|
use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter};
|
||||||
|
use crate::csr::CsrMatrix;
|
||||||
|
use crate::cs::{CsMatrix, CsLane, CsLaneMut, CsLaneIter, CsLaneIterMut};
|
||||||
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::slice::{IterMut, Iter};
|
use std::slice::{IterMut, Iter};
|
||||||
use std::ops::Range;
|
|
||||||
use num_traits::Zero;
|
use num_traits::Zero;
|
||||||
use std::ptr::slice_from_raw_parts_mut;
|
|
||||||
use crate::csr::CsrMatrix;
|
|
||||||
use nalgebra::Scalar;
|
use nalgebra::Scalar;
|
||||||
|
|
||||||
/// A CSC representation of a sparse matrix.
|
/// A CSC representation of a sparse matrix.
|
||||||
|
@ -21,29 +20,27 @@ use nalgebra::Scalar;
|
||||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
pub struct CscMatrix<T> {
|
pub struct CscMatrix<T> {
|
||||||
// Cols are major, rows are minor in the sparsity pattern
|
// Cols are major, rows are minor in the sparsity pattern
|
||||||
sparsity_pattern: Arc<SparsityPattern>,
|
cs: CsMatrix<T>,
|
||||||
values: Vec<T>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T> CscMatrix<T> {
|
impl<T> CscMatrix<T> {
|
||||||
/// Create a zero CSC matrix with no explicitly stored entries.
|
/// Create a zero CSC matrix with no explicitly stored entries.
|
||||||
pub fn new(nrows: usize, ncols: usize) -> Self {
|
pub fn new(nrows: usize, ncols: usize) -> Self {
|
||||||
Self {
|
Self {
|
||||||
sparsity_pattern: Arc::new(SparsityPattern::new(ncols, nrows)),
|
cs: CsMatrix::new(ncols, nrows)
|
||||||
values: vec![],
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// The number of rows in the matrix.
|
/// The number of rows in the matrix.
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn nrows(&self) -> usize {
|
pub fn nrows(&self) -> usize {
|
||||||
self.sparsity_pattern.minor_dim()
|
self.cs.pattern().minor_dim()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// The number of columns in the matrix.
|
/// The number of columns in the matrix.
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn ncols(&self) -> usize {
|
pub fn ncols(&self) -> usize {
|
||||||
self.sparsity_pattern.major_dim()
|
self.cs.pattern().major_dim()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// The number of non-zeros in the matrix.
|
/// The number of non-zeros in the matrix.
|
||||||
|
@ -53,31 +50,31 @@ impl<T> CscMatrix<T> {
|
||||||
/// be zero. Corresponds to the number of entries in the sparsity pattern.
|
/// be zero. Corresponds to the number of entries in the sparsity pattern.
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn nnz(&self) -> usize {
|
pub fn nnz(&self) -> usize {
|
||||||
self.sparsity_pattern.nnz()
|
self.pattern().nnz()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// The column offsets defining part of the CSC format.
|
/// The column offsets defining part of the CSC format.
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn col_offsets(&self) -> &[usize] {
|
pub fn col_offsets(&self) -> &[usize] {
|
||||||
self.sparsity_pattern.major_offsets()
|
self.pattern().major_offsets()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// The row indices defining part of the CSC format.
|
/// The row indices defining part of the CSC format.
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn row_indices(&self) -> &[usize] {
|
pub fn row_indices(&self) -> &[usize] {
|
||||||
self.sparsity_pattern.minor_indices()
|
self.pattern().minor_indices()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// The non-zero values defining part of the CSC format.
|
/// The non-zero values defining part of the CSC format.
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn values(&self) -> &[T] {
|
pub fn values(&self) -> &[T] {
|
||||||
&self.values
|
self.cs.values()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Mutable access to the non-zero values.
|
/// Mutable access to the non-zero values.
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn values_mut(&mut self) -> &mut [T] {
|
pub fn values_mut(&mut self) -> &mut [T] {
|
||||||
&mut self.values
|
self.cs.values_mut()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Try to construct a CSC matrix from raw CSC data.
|
/// Try to construct a CSC matrix from raw CSC data.
|
||||||
|
@ -109,8 +106,7 @@ impl<T> CscMatrix<T> {
|
||||||
-> Result<Self, SparseFormatError> {
|
-> Result<Self, SparseFormatError> {
|
||||||
if pattern.nnz() == values.len() {
|
if pattern.nnz() == values.len() {
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
sparsity_pattern: pattern,
|
cs: CsMatrix::from_pattern_and_values(pattern, values)
|
||||||
values,
|
|
||||||
})
|
})
|
||||||
} else {
|
} else {
|
||||||
Err(SparseFormatError::from_kind_and_msg(
|
Err(SparseFormatError::from_kind_and_msg(
|
||||||
|
@ -140,8 +136,8 @@ impl<T> CscMatrix<T> {
|
||||||
/// ```
|
/// ```
|
||||||
pub fn triplet_iter(&self) -> CscTripletIter<T> {
|
pub fn triplet_iter(&self) -> CscTripletIter<T> {
|
||||||
CscTripletIter {
|
CscTripletIter {
|
||||||
pattern_iter: self.sparsity_pattern.entries(),
|
pattern_iter: self.pattern().entries(),
|
||||||
values_iter: self.values.iter()
|
values_iter: self.values().iter()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -169,9 +165,10 @@ impl<T> CscMatrix<T> {
|
||||||
/// assert_eq!(triplets, vec![(0, 0, 1), (2, 0, 0), (1, 1, 2), (0, 2, 4)]);
|
/// assert_eq!(triplets, vec![(0, 0, 1), (2, 0, 0), (1, 1, 2), (0, 2, 4)]);
|
||||||
/// ```
|
/// ```
|
||||||
pub fn triplet_iter_mut(&mut self) -> CscTripletIterMut<T> {
|
pub fn triplet_iter_mut(&mut self) -> CscTripletIterMut<T> {
|
||||||
|
let (pattern, values) = self.cs.pattern_and_values_mut();
|
||||||
CscTripletIterMut {
|
CscTripletIterMut {
|
||||||
pattern_iter: self.sparsity_pattern.entries(),
|
pattern_iter: pattern.entries(),
|
||||||
values_mut_iter: self.values.iter_mut()
|
values_mut_iter: values.iter_mut()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -200,54 +197,34 @@ impl<T> CscMatrix<T> {
|
||||||
/// Return the column at the given column index, or `None` if out of bounds.
|
/// Return the column at the given column index, or `None` if out of bounds.
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn get_col(&self, index: usize) -> Option<CscCol<T>> {
|
pub fn get_col(&self, index: usize) -> Option<CscCol<T>> {
|
||||||
let range = self.get_index_range(index)?;
|
self.cs
|
||||||
Some(CscCol {
|
.get_lane(index)
|
||||||
row_indices: &self.sparsity_pattern.minor_indices()[range.clone()],
|
.map(|lane| CscCol { lane })
|
||||||
values: &self.values[range],
|
|
||||||
nrows: self.nrows()
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Mutable column access for the given column index, or `None` if out of bounds.
|
/// Mutable column access for the given column index, or `None` if out of bounds.
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn get_col_mut(&mut self, index: usize) -> Option<CscColMut<T>> {
|
pub fn get_col_mut(&mut self, index: usize) -> Option<CscColMut<T>> {
|
||||||
let range = self.get_index_range(index)?;
|
self.cs
|
||||||
Some(CscColMut {
|
.get_lane_mut(index)
|
||||||
nrows: self.nrows(),
|
.map(|lane| CscColMut { lane })
|
||||||
row_indices: &self.sparsity_pattern.minor_indices()[range.clone()],
|
|
||||||
values: &mut self.values[range]
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Internal method for simplifying access to a column's data.
|
|
||||||
fn get_index_range(&self, col_index: usize) -> Option<Range<usize>> {
|
|
||||||
let col_begin = *self.sparsity_pattern.major_offsets().get(col_index)?;
|
|
||||||
let col_end = *self.sparsity_pattern.major_offsets().get(col_index + 1)?;
|
|
||||||
Some(col_begin .. col_end)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// An iterator over columns in the matrix.
|
/// An iterator over columns in the matrix.
|
||||||
pub fn col_iter(&self) -> CscColIter<T> {
|
pub fn col_iter(&self) -> CscColIter<T> {
|
||||||
CscColIter {
|
CscColIter {
|
||||||
current_col_idx: 0,
|
lane_iter: CsLaneIter::new(self.pattern().as_ref(), self.values())
|
||||||
matrix: self
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// A mutable iterator over columns in the matrix.
|
/// A mutable iterator over columns in the matrix.
|
||||||
pub fn col_iter_mut(&mut self) -> CscColIterMut<T> {
|
pub fn col_iter_mut(&mut self) -> CscColIterMut<T> {
|
||||||
|
let (pattern, values) = self.cs.pattern_and_values_mut();
|
||||||
CscColIterMut {
|
CscColIterMut {
|
||||||
current_col_idx: 0,
|
lane_iter: CsLaneIterMut::new(pattern, values)
|
||||||
pattern: &self.sparsity_pattern,
|
|
||||||
remaining_values: self.values.as_mut_ptr()
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns the underlying vector containing the values for the explicitly stored entries.
|
|
||||||
pub fn take_values(self) -> Vec<T> {
|
|
||||||
self.values
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Disassembles the CSC matrix into its underlying offset, index and value arrays.
|
/// Disassembles the CSC matrix into its underlying offset, index and value arrays.
|
||||||
///
|
///
|
||||||
/// If the matrix contains the sole reference to the sparsity pattern,
|
/// If the matrix contains the sole reference to the sparsity pattern,
|
||||||
|
@ -274,19 +251,7 @@ impl<T> CscMatrix<T> {
|
||||||
/// assert_eq!(values2, values);
|
/// assert_eq!(values2, values);
|
||||||
/// ```
|
/// ```
|
||||||
pub fn disassemble(self) -> (Vec<usize>, Vec<usize>, Vec<T>) {
|
pub fn disassemble(self) -> (Vec<usize>, Vec<usize>, Vec<T>) {
|
||||||
// Take an Arc to the pattern, which might be the sole reference to the data after
|
self.cs.disassemble()
|
||||||
// taking the values. This is important, because it might let us avoid cloning the data
|
|
||||||
// further below.
|
|
||||||
let pattern = self.sparsity_pattern;
|
|
||||||
let values = self.values;
|
|
||||||
|
|
||||||
// Try to take the pattern out of the `Arc` if possible,
|
|
||||||
// otherwise clone the pattern.
|
|
||||||
let owned_pattern = Arc::try_unwrap(pattern)
|
|
||||||
.unwrap_or_else(|arc| SparsityPattern::clone(&*arc));
|
|
||||||
let (offsets, indices) = owned_pattern.disassemble();
|
|
||||||
|
|
||||||
(offsets, indices, values)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns the underlying sparsity pattern.
|
/// Returns the underlying sparsity pattern.
|
||||||
|
@ -295,15 +260,14 @@ impl<T> CscMatrix<T> {
|
||||||
/// the same sparsity pattern for multiple matrices without storing the same pattern multiple
|
/// the same sparsity pattern for multiple matrices without storing the same pattern multiple
|
||||||
/// times in memory.
|
/// times in memory.
|
||||||
pub fn pattern(&self) -> &Arc<SparsityPattern> {
|
pub fn pattern(&self) -> &Arc<SparsityPattern> {
|
||||||
&self.sparsity_pattern
|
self.cs.pattern()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Reinterprets the CSC matrix as its transpose represented by a CSR matrix.
|
/// Reinterprets the CSC matrix as its transpose represented by a CSR matrix.
|
||||||
///
|
///
|
||||||
/// This operation does not touch the CSC data, and is effectively a no-op.
|
/// This operation does not touch the CSC data, and is effectively a no-op.
|
||||||
pub fn transpose_as_csr(self) -> CsrMatrix<T> {
|
pub fn transpose_as_csr(self) -> CsrMatrix<T> {
|
||||||
let pattern = self.sparsity_pattern;
|
let (pattern, values) = self.cs.take_pattern_and_values();
|
||||||
let values = self.values;
|
|
||||||
CsrMatrix::try_from_pattern_and_values(pattern, values).unwrap()
|
CsrMatrix::try_from_pattern_and_values(pattern, values).unwrap()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -422,9 +386,7 @@ impl<'a, T> Iterator for CscTripletIterMut<'a, T> {
|
||||||
/// An immutable representation of a column in a CSC matrix.
|
/// An immutable representation of a column in a CSC matrix.
|
||||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
pub struct CscCol<'a, T> {
|
pub struct CscCol<'a, T> {
|
||||||
nrows: usize,
|
lane: CsLane<'a, T>
|
||||||
row_indices: &'a [usize],
|
|
||||||
values: &'a [T],
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// A mutable representation of a column in a CSC matrix.
|
/// A mutable representation of a column in a CSC matrix.
|
||||||
|
@ -433,9 +395,7 @@ pub struct CscCol<'a, T> {
|
||||||
/// to the column cannot be modified.
|
/// to the column cannot be modified.
|
||||||
#[derive(Debug, PartialEq, Eq)]
|
#[derive(Debug, PartialEq, Eq)]
|
||||||
pub struct CscColMut<'a, T> {
|
pub struct CscColMut<'a, T> {
|
||||||
nrows: usize,
|
lane: CsLaneMut<'a, T>
|
||||||
row_indices: &'a [usize],
|
|
||||||
values: &'a mut [T]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Implement the methods common to both CscCol and CscColMut
|
/// Implement the methods common to both CscCol and CscColMut
|
||||||
|
@ -445,25 +405,25 @@ macro_rules! impl_csc_col_common_methods {
|
||||||
/// The number of global rows in the column.
|
/// The number of global rows in the column.
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn nrows(&self) -> usize {
|
pub fn nrows(&self) -> usize {
|
||||||
self.nrows
|
self.lane.minor_dim
|
||||||
}
|
}
|
||||||
|
|
||||||
/// The number of non-zeros in this column.
|
/// The number of non-zeros in this column.
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn nnz(&self) -> usize {
|
pub fn nnz(&self) -> usize {
|
||||||
self.row_indices.len()
|
self.lane.minor_indices.len()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// The row indices corresponding to explicitly stored entries in this column.
|
/// The row indices corresponding to explicitly stored entries in this column.
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn row_indices(&self) -> &[usize] {
|
pub fn row_indices(&self) -> &[usize] {
|
||||||
self.row_indices
|
self.lane.minor_indices
|
||||||
}
|
}
|
||||||
|
|
||||||
/// The values corresponding to explicitly stored entries in this column.
|
/// The values corresponding to explicitly stored entries in this column.
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn values(&self) -> &[T] {
|
pub fn values(&self) -> &[T] {
|
||||||
self.values
|
self.lane.values
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -480,8 +440,8 @@ macro_rules! impl_csc_col_common_methods {
|
||||||
pub fn get(&self, global_row_index: usize) -> Option<T> {
|
pub fn get(&self, global_row_index: usize) -> Option<T> {
|
||||||
let local_index = self.row_indices().binary_search(&global_row_index);
|
let local_index = self.row_indices().binary_search(&global_row_index);
|
||||||
if let Ok(local_index) = local_index {
|
if let Ok(local_index) = local_index {
|
||||||
Some(self.values[local_index].clone())
|
Some(self.values()[local_index].clone())
|
||||||
} else if global_row_index < self.nrows {
|
} else if global_row_index < self.lane.minor_dim {
|
||||||
Some(T::zero())
|
Some(T::zero())
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
|
@ -497,7 +457,7 @@ impl_csc_col_common_methods!(CscColMut<'a, T>);
|
||||||
impl<'a, T> CscColMut<'a, T> {
|
impl<'a, T> CscColMut<'a, T> {
|
||||||
/// Mutable access to the values corresponding to explicitly stored entries in this column.
|
/// Mutable access to the values corresponding to explicitly stored entries in this column.
|
||||||
pub fn values_mut(&mut self) -> &mut [T] {
|
pub fn values_mut(&mut self) -> &mut [T] {
|
||||||
self.values
|
self.lane.values
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Provides simultaneous access to row indices and mutable values corresponding to the
|
/// Provides simultaneous access to row indices and mutable values corresponding to the
|
||||||
|
@ -506,32 +466,28 @@ impl<'a, T> CscColMut<'a, T> {
|
||||||
/// This method primarily facilitates low-level access for methods that process data stored
|
/// This method primarily facilitates low-level access for methods that process data stored
|
||||||
/// in CSC format directly.
|
/// in CSC format directly.
|
||||||
pub fn rows_and_values_mut(&mut self) -> (&[usize], &mut [T]) {
|
pub fn rows_and_values_mut(&mut self) -> (&[usize], &mut [T]) {
|
||||||
(self.row_indices, self.values)
|
(self.lane.minor_indices, self.lane.values)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Column iterator for [CscMatrix](struct.CscMatrix.html).
|
/// Column iterator for [CscMatrix](struct.CscMatrix.html).
|
||||||
pub struct CscColIter<'a, T> {
|
pub struct CscColIter<'a, T> {
|
||||||
// The index of the row that will be returned on the next
|
lane_iter: CsLaneIter<'a, T>
|
||||||
current_col_idx: usize,
|
|
||||||
matrix: &'a CscMatrix<T>
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, T> Iterator for CscColIter<'a, T> {
|
impl<'a, T> Iterator for CscColIter<'a, T> {
|
||||||
type Item = CscCol<'a, T>;
|
type Item = CscCol<'a, T>;
|
||||||
|
|
||||||
fn next(&mut self) -> Option<Self::Item> {
|
fn next(&mut self) -> Option<Self::Item> {
|
||||||
let col = self.matrix.get_col(self.current_col_idx);
|
self.lane_iter
|
||||||
self.current_col_idx += 1;
|
.next()
|
||||||
col
|
.map(|lane| CscCol { lane })
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Mutable column iterator for [CscMatrix](struct.CscMatrix.html).
|
/// Mutable column iterator for [CscMatrix](struct.CscMatrix.html).
|
||||||
pub struct CscColIterMut<'a, T> {
|
pub struct CscColIterMut<'a, T> {
|
||||||
current_col_idx: usize,
|
lane_iter: CsLaneIterMut<'a, T>
|
||||||
pattern: &'a SparsityPattern,
|
|
||||||
remaining_values: *mut T,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, T> Iterator for CscColIterMut<'a, T>
|
impl<'a, T> Iterator for CscColIterMut<'a, T>
|
||||||
|
@ -541,27 +497,8 @@ where
|
||||||
type Item = CscColMut<'a, T>;
|
type Item = CscColMut<'a, T>;
|
||||||
|
|
||||||
fn next(&mut self) -> Option<Self::Item> {
|
fn next(&mut self) -> Option<Self::Item> {
|
||||||
let lane = self.pattern.get_lane(self.current_col_idx);
|
self.lane_iter
|
||||||
let nrows = self.pattern.minor_dim();
|
.next()
|
||||||
|
.map(|lane| CscColMut { lane })
|
||||||
if let Some(row_indices) = lane {
|
|
||||||
let count = row_indices.len();
|
|
||||||
|
|
||||||
// Note: I can't think of any way to construct this iterator without unsafe.
|
|
||||||
let values_in_row;
|
|
||||||
unsafe {
|
|
||||||
values_in_row = &mut *slice_from_raw_parts_mut(self.remaining_values, count);
|
|
||||||
self.remaining_values = self.remaining_values.add(count);
|
|
||||||
}
|
|
||||||
self.current_col_idx += 1;
|
|
||||||
|
|
||||||
Some(CscColMut {
|
|
||||||
nrows,
|
|
||||||
row_indices,
|
|
||||||
values: values_in_row
|
|
||||||
})
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
|
@ -2,14 +2,13 @@
|
||||||
use crate::{SparseFormatError, SparseFormatErrorKind, SparseEntry, SparseEntryMut};
|
use crate::{SparseFormatError, SparseFormatErrorKind, SparseEntry, SparseEntryMut};
|
||||||
use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter};
|
use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter};
|
||||||
use crate::csc::CscMatrix;
|
use crate::csc::CscMatrix;
|
||||||
|
use crate::cs::{CsMatrix, get_entry_from_slices, get_mut_entry_from_slices, CsLaneIterMut, CsLaneIter, CsLane, CsLaneMut};
|
||||||
|
|
||||||
use nalgebra::Scalar;
|
use nalgebra::Scalar;
|
||||||
use num_traits::Zero;
|
use num_traits::Zero;
|
||||||
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::slice::{IterMut, Iter};
|
use std::slice::{IterMut, Iter};
|
||||||
use std::ops::Range;
|
|
||||||
use std::ptr::slice_from_raw_parts_mut;
|
|
||||||
|
|
||||||
/// A CSR representation of a sparse matrix.
|
/// A CSR representation of a sparse matrix.
|
||||||
///
|
///
|
||||||
|
@ -21,29 +20,27 @@ use std::ptr::slice_from_raw_parts_mut;
|
||||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
pub struct CsrMatrix<T> {
|
pub struct CsrMatrix<T> {
|
||||||
// Rows are major, cols are minor in the sparsity pattern
|
// Rows are major, cols are minor in the sparsity pattern
|
||||||
sparsity_pattern: Arc<SparsityPattern>,
|
cs: CsMatrix<T>,
|
||||||
values: Vec<T>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T> CsrMatrix<T> {
|
impl<T> CsrMatrix<T> {
|
||||||
/// Create a zero CSR matrix with no explicitly stored entries.
|
/// Create a zero CSR matrix with no explicitly stored entries.
|
||||||
pub fn new(nrows: usize, ncols: usize) -> Self {
|
pub fn new(nrows: usize, ncols: usize) -> Self {
|
||||||
Self {
|
Self {
|
||||||
sparsity_pattern: Arc::new(SparsityPattern::new(nrows, ncols)),
|
cs: CsMatrix::new(nrows, ncols)
|
||||||
values: vec![],
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// The number of rows in the matrix.
|
/// The number of rows in the matrix.
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn nrows(&self) -> usize {
|
pub fn nrows(&self) -> usize {
|
||||||
self.sparsity_pattern.major_dim()
|
self.cs.pattern().major_dim()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// The number of columns in the matrix.
|
/// The number of columns in the matrix.
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn ncols(&self) -> usize {
|
pub fn ncols(&self) -> usize {
|
||||||
self.sparsity_pattern.minor_dim()
|
self.cs.pattern().minor_dim()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// The number of non-zeros in the matrix.
|
/// The number of non-zeros in the matrix.
|
||||||
|
@ -53,31 +50,33 @@ impl<T> CsrMatrix<T> {
|
||||||
/// be zero. Corresponds to the number of entries in the sparsity pattern.
|
/// be zero. Corresponds to the number of entries in the sparsity pattern.
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn nnz(&self) -> usize {
|
pub fn nnz(&self) -> usize {
|
||||||
self.sparsity_pattern.nnz()
|
self.cs.pattern().nnz()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// The row offsets defining part of the CSR format.
|
/// The row offsets defining part of the CSR format.
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn row_offsets(&self) -> &[usize] {
|
pub fn row_offsets(&self) -> &[usize] {
|
||||||
self.sparsity_pattern.major_offsets()
|
let (offsets, _, _) = self.cs.cs_data();
|
||||||
|
offsets
|
||||||
}
|
}
|
||||||
|
|
||||||
/// The column indices defining part of the CSR format.
|
/// The column indices defining part of the CSR format.
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn col_indices(&self) -> &[usize] {
|
pub fn col_indices(&self) -> &[usize] {
|
||||||
self.sparsity_pattern.minor_indices()
|
let (_, indices, _) = self.cs.cs_data();
|
||||||
|
indices
|
||||||
}
|
}
|
||||||
|
|
||||||
/// The non-zero values defining part of the CSR format.
|
/// The non-zero values defining part of the CSR format.
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn values(&self) -> &[T] {
|
pub fn values(&self) -> &[T] {
|
||||||
&self.values
|
self.cs.values()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Mutable access to the non-zero values.
|
/// Mutable access to the non-zero values.
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn values_mut(&mut self) -> &mut [T] {
|
pub fn values_mut(&mut self) -> &mut [T] {
|
||||||
&mut self.values
|
self.cs.values_mut()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Try to construct a CSR matrix from raw CSR data.
|
/// Try to construct a CSR matrix from raw CSR data.
|
||||||
|
@ -109,8 +108,7 @@ impl<T> CsrMatrix<T> {
|
||||||
-> Result<Self, SparseFormatError> {
|
-> Result<Self, SparseFormatError> {
|
||||||
if pattern.nnz() == values.len() {
|
if pattern.nnz() == values.len() {
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
sparsity_pattern: pattern,
|
cs: CsMatrix::from_pattern_and_values(pattern, values)
|
||||||
values,
|
|
||||||
})
|
})
|
||||||
} else {
|
} else {
|
||||||
Err(SparseFormatError::from_kind_and_msg(
|
Err(SparseFormatError::from_kind_and_msg(
|
||||||
|
@ -119,7 +117,6 @@ impl<T> CsrMatrix<T> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/// An iterator over non-zero triplets (i, j, v).
|
/// An iterator over non-zero triplets (i, j, v).
|
||||||
///
|
///
|
||||||
/// The iteration happens in row-major fashion, meaning that i increases monotonically,
|
/// The iteration happens in row-major fashion, meaning that i increases monotonically,
|
||||||
|
@ -140,8 +137,8 @@ impl<T> CsrMatrix<T> {
|
||||||
/// ```
|
/// ```
|
||||||
pub fn triplet_iter(&self) -> CsrTripletIter<T> {
|
pub fn triplet_iter(&self) -> CsrTripletIter<T> {
|
||||||
CsrTripletIter {
|
CsrTripletIter {
|
||||||
pattern_iter: self.sparsity_pattern.entries(),
|
pattern_iter: self.pattern().entries(),
|
||||||
values_iter: self.values.iter()
|
values_iter: self.values().iter()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -169,9 +166,10 @@ impl<T> CsrMatrix<T> {
|
||||||
/// assert_eq!(triplets, vec![(0, 0, 1), (0, 2, 2), (1, 1, 3), (2, 0, 0)]);
|
/// assert_eq!(triplets, vec![(0, 0, 1), (0, 2, 2), (1, 1, 3), (2, 0, 0)]);
|
||||||
/// ```
|
/// ```
|
||||||
pub fn triplet_iter_mut(&mut self) -> CsrTripletIterMut<T> {
|
pub fn triplet_iter_mut(&mut self) -> CsrTripletIterMut<T> {
|
||||||
|
let (pattern, values) = self.cs.pattern_and_values_mut();
|
||||||
CsrTripletIterMut {
|
CsrTripletIterMut {
|
||||||
pattern_iter: self.sparsity_pattern.entries(),
|
pattern_iter: pattern.entries(),
|
||||||
values_mut_iter: self.values.iter_mut()
|
values_mut_iter: values.iter_mut()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -200,54 +198,34 @@ impl<T> CsrMatrix<T> {
|
||||||
/// Return the row at the given row index, or `None` if out of bounds.
|
/// Return the row at the given row index, or `None` if out of bounds.
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn get_row(&self, index: usize) -> Option<CsrRow<T>> {
|
pub fn get_row(&self, index: usize) -> Option<CsrRow<T>> {
|
||||||
let range = self.get_index_range(index)?;
|
self.cs
|
||||||
Some(CsrRow {
|
.get_lane(index)
|
||||||
col_indices: &self.sparsity_pattern.minor_indices()[range.clone()],
|
.map(|lane| CsrRow { lane })
|
||||||
values: &self.values[range],
|
|
||||||
ncols: self.ncols()
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Mutable row access for the given row index, or `None` if out of bounds.
|
/// Mutable row access for the given row index, or `None` if out of bounds.
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn get_row_mut(&mut self, index: usize) -> Option<CsrRowMut<T>> {
|
pub fn get_row_mut(&mut self, index: usize) -> Option<CsrRowMut<T>> {
|
||||||
let range = self.get_index_range(index)?;
|
self.cs
|
||||||
Some(CsrRowMut {
|
.get_lane_mut(index)
|
||||||
ncols: self.ncols(),
|
.map(|lane| CsrRowMut { lane })
|
||||||
col_indices: &self.sparsity_pattern.minor_indices()[range.clone()],
|
|
||||||
values: &mut self.values[range]
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Internal method for simplifying access to a row's data.
|
|
||||||
fn get_index_range(&self, row_index: usize) -> Option<Range<usize>> {
|
|
||||||
let row_begin = *self.sparsity_pattern.major_offsets().get(row_index)?;
|
|
||||||
let row_end = *self.sparsity_pattern.major_offsets().get(row_index + 1)?;
|
|
||||||
Some(row_begin .. row_end)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// An iterator over rows in the matrix.
|
/// An iterator over rows in the matrix.
|
||||||
pub fn row_iter(&self) -> CsrRowIter<T> {
|
pub fn row_iter(&self) -> CsrRowIter<T> {
|
||||||
CsrRowIter {
|
CsrRowIter {
|
||||||
current_row_idx: 0,
|
lane_iter: CsLaneIter::new(self.pattern().as_ref(), self.values())
|
||||||
matrix: self
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// A mutable iterator over rows in the matrix.
|
/// A mutable iterator over rows in the matrix.
|
||||||
pub fn row_iter_mut(&mut self) -> CsrRowIterMut<T> {
|
pub fn row_iter_mut(&mut self) -> CsrRowIterMut<T> {
|
||||||
|
let (pattern, values) = self.cs.pattern_and_values_mut();
|
||||||
CsrRowIterMut {
|
CsrRowIterMut {
|
||||||
current_row_idx: 0,
|
lane_iter: CsLaneIterMut::new(pattern, values),
|
||||||
pattern: &self.sparsity_pattern,
|
|
||||||
remaining_values: self.values.as_mut_ptr()
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns the underlying vector containing the values for the explicitly stored entries.
|
|
||||||
pub fn take_values(self) -> Vec<T> {
|
|
||||||
self.values
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Disassembles the CSR matrix into its underlying offset, index and value arrays.
|
/// Disassembles the CSR matrix into its underlying offset, index and value arrays.
|
||||||
///
|
///
|
||||||
/// If the matrix contains the sole reference to the sparsity pattern,
|
/// If the matrix contains the sole reference to the sparsity pattern,
|
||||||
|
@ -274,19 +252,7 @@ impl<T> CsrMatrix<T> {
|
||||||
/// assert_eq!(values2, values);
|
/// assert_eq!(values2, values);
|
||||||
/// ```
|
/// ```
|
||||||
pub fn disassemble(self) -> (Vec<usize>, Vec<usize>, Vec<T>) {
|
pub fn disassemble(self) -> (Vec<usize>, Vec<usize>, Vec<T>) {
|
||||||
// Take an Arc to the pattern, which might be the sole reference to the data after
|
self.cs.disassemble()
|
||||||
// taking the values. This is important, because it might let us avoid cloning the data
|
|
||||||
// further below.
|
|
||||||
let pattern = self.sparsity_pattern;
|
|
||||||
let values = self.values;
|
|
||||||
|
|
||||||
// Try to take the pattern out of the `Arc` if possible,
|
|
||||||
// otherwise clone the pattern.
|
|
||||||
let owned_pattern = Arc::try_unwrap(pattern)
|
|
||||||
.unwrap_or_else(|arc| SparsityPattern::clone(&*arc));
|
|
||||||
let (offsets, indices) = owned_pattern.disassemble();
|
|
||||||
|
|
||||||
(offsets, indices, values)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns the underlying sparsity pattern.
|
/// Returns the underlying sparsity pattern.
|
||||||
|
@ -295,15 +261,14 @@ impl<T> CsrMatrix<T> {
|
||||||
/// the same sparsity pattern for multiple matrices without storing the same pattern multiple
|
/// the same sparsity pattern for multiple matrices without storing the same pattern multiple
|
||||||
/// times in memory.
|
/// times in memory.
|
||||||
pub fn pattern(&self) -> &Arc<SparsityPattern> {
|
pub fn pattern(&self) -> &Arc<SparsityPattern> {
|
||||||
&self.sparsity_pattern
|
self.cs.pattern()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Reinterprets the CSR matrix as its transpose represented by a CSC matrix.
|
/// Reinterprets the CSR matrix as its transpose represented by a CSC matrix.
|
||||||
///
|
///
|
||||||
/// This operation does not touch the CSR data, and is effectively a no-op.
|
/// This operation does not touch the CSR data, and is effectively a no-op.
|
||||||
pub fn transpose_as_csc(self) -> CscMatrix<T> {
|
pub fn transpose_as_csc(self) -> CscMatrix<T> {
|
||||||
let pattern = self.sparsity_pattern;
|
let (pattern, values) = self.cs.take_pattern_and_values();
|
||||||
let values = self.values;
|
|
||||||
CscMatrix::try_from_pattern_and_values(pattern, values).unwrap()
|
CscMatrix::try_from_pattern_and_values(pattern, values).unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -312,10 +277,7 @@ impl<T> CsrMatrix<T> {
|
||||||
/// Each call to this function incurs the cost of a binary search among the explicitly
|
/// Each call to this function incurs the cost of a binary search among the explicitly
|
||||||
/// stored column entries for the given row.
|
/// stored column entries for the given row.
|
||||||
pub fn get_entry(&self, row_index: usize, col_index: usize) -> Option<SparseEntry<T>> {
|
pub fn get_entry(&self, row_index: usize, col_index: usize) -> Option<SparseEntry<T>> {
|
||||||
let row_range = self.get_index_range(row_index)?;
|
self.cs.get_entry(row_index, col_index)
|
||||||
let col_indices = &self.col_indices()[row_range.clone()];
|
|
||||||
let values = &self.values()[row_range];
|
|
||||||
get_entry_from_slices(self.ncols(), col_indices, values, col_index)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns a mutable entry for the given row/col indices, or `None` if the indices are out
|
/// Returns a mutable entry for the given row/col indices, or `None` if the indices are out
|
||||||
|
@ -325,12 +287,7 @@ impl<T> CsrMatrix<T> {
|
||||||
/// stored column entries for the given row.
|
/// stored column entries for the given row.
|
||||||
pub fn get_entry_mut(&mut self, row_index: usize, col_index: usize)
|
pub fn get_entry_mut(&mut self, row_index: usize, col_index: usize)
|
||||||
-> Option<SparseEntryMut<T>> {
|
-> Option<SparseEntryMut<T>> {
|
||||||
let row_range = self.get_index_range(row_index)?;
|
self.cs.get_entry_mut(row_index, col_index)
|
||||||
let ncols = self.ncols();
|
|
||||||
let (_, col_indices, values) = self.csr_data_mut();
|
|
||||||
let col_indices = &col_indices[row_range.clone()];
|
|
||||||
let values = &mut values[row_range];
|
|
||||||
get_mut_entry_from_slices(ncols, col_indices, values, col_index)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns an entry for the given row/col indices.
|
/// Returns an entry for the given row/col indices.
|
||||||
|
@ -361,14 +318,13 @@ impl<T> CsrMatrix<T> {
|
||||||
|
|
||||||
/// Returns a triplet of slices `(row_offsets, col_indices, values)` that make up the CSR data.
|
/// Returns a triplet of slices `(row_offsets, col_indices, values)` that make up the CSR data.
|
||||||
pub fn csr_data(&self) -> (&[usize], &[usize], &[T]) {
|
pub fn csr_data(&self) -> (&[usize], &[usize], &[T]) {
|
||||||
(self.row_offsets(), self.col_indices(), self.values())
|
self.cs.cs_data()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns a triplet of slices `(row_offsets, col_indices, values)` that make up the CSR data,
|
/// Returns a triplet of slices `(row_offsets, col_indices, values)` that make up the CSR data,
|
||||||
/// where the `values` array is mutable.
|
/// where the `values` array is mutable.
|
||||||
pub fn csr_data_mut(&mut self) -> (&[usize], &[usize], &mut [T]) {
|
pub fn csr_data_mut(&mut self) -> (&[usize], &[usize], &mut [T]) {
|
||||||
let pattern = self.sparsity_pattern.as_ref();
|
self.cs.cs_data_mut()
|
||||||
(pattern.major_offsets(), pattern.minor_indices(), &mut self.values)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -460,9 +416,7 @@ impl<'a, T> Iterator for CsrTripletIterMut<'a, T> {
|
||||||
/// An immutable representation of a row in a CSR matrix.
|
/// An immutable representation of a row in a CSR matrix.
|
||||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
pub struct CsrRow<'a, T> {
|
pub struct CsrRow<'a, T> {
|
||||||
ncols: usize,
|
lane: CsLane<'a, T>
|
||||||
col_indices: &'a [usize],
|
|
||||||
values: &'a [T],
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// A mutable representation of a row in a CSR matrix.
|
/// A mutable representation of a row in a CSR matrix.
|
||||||
|
@ -471,9 +425,7 @@ pub struct CsrRow<'a, T> {
|
||||||
/// to the row cannot be modified.
|
/// to the row cannot be modified.
|
||||||
#[derive(Debug, PartialEq, Eq)]
|
#[derive(Debug, PartialEq, Eq)]
|
||||||
pub struct CsrRowMut<'a, T> {
|
pub struct CsrRowMut<'a, T> {
|
||||||
ncols: usize,
|
lane: CsLaneMut<'a, T>
|
||||||
col_indices: &'a [usize],
|
|
||||||
values: &'a mut [T]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Implement the methods common to both CsrRow and CsrRowMut
|
/// Implement the methods common to both CsrRow and CsrRowMut
|
||||||
|
@ -483,25 +435,25 @@ macro_rules! impl_csr_row_common_methods {
|
||||||
/// The number of global columns in the row.
|
/// The number of global columns in the row.
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn ncols(&self) -> usize {
|
pub fn ncols(&self) -> usize {
|
||||||
self.ncols
|
self.lane.minor_dim
|
||||||
}
|
}
|
||||||
|
|
||||||
/// The number of non-zeros in this row.
|
/// The number of non-zeros in this row.
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn nnz(&self) -> usize {
|
pub fn nnz(&self) -> usize {
|
||||||
self.col_indices.len()
|
self.lane.minor_indices.len()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// The column indices corresponding to explicitly stored entries in this row.
|
/// The column indices corresponding to explicitly stored entries in this row.
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn col_indices(&self) -> &[usize] {
|
pub fn col_indices(&self) -> &[usize] {
|
||||||
self.col_indices
|
self.lane.minor_indices
|
||||||
}
|
}
|
||||||
|
|
||||||
/// The values corresponding to explicitly stored entries in this row.
|
/// The values corresponding to explicitly stored entries in this row.
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn values(&self) -> &[T] {
|
pub fn values(&self) -> &[T] {
|
||||||
self.values
|
self.lane.values
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns an entry for the given global column index.
|
/// Returns an entry for the given global column index.
|
||||||
|
@ -509,47 +461,23 @@ macro_rules! impl_csr_row_common_methods {
|
||||||
/// Each call to this function incurs the cost of a binary search among the explicitly
|
/// Each call to this function incurs the cost of a binary search among the explicitly
|
||||||
/// stored column entries.
|
/// stored column entries.
|
||||||
pub fn get_entry(&self, global_col_index: usize) -> Option<SparseEntry<T>> {
|
pub fn get_entry(&self, global_col_index: usize) -> Option<SparseEntry<T>> {
|
||||||
get_entry_from_slices(self.ncols, self.col_indices, self.values, global_col_index)
|
get_entry_from_slices(
|
||||||
|
self.lane.minor_dim,
|
||||||
|
self.lane.minor_indices,
|
||||||
|
self.lane.values,
|
||||||
|
global_col_index)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_entry_from_slices<'a, T>(ncols: usize,
|
|
||||||
col_indices: &'a [usize],
|
|
||||||
values: &'a [T],
|
|
||||||
global_col_index: usize) -> Option<SparseEntry<'a, T>> {
|
|
||||||
let local_index = col_indices.binary_search(&global_col_index);
|
|
||||||
if let Ok(local_index) = local_index {
|
|
||||||
Some(SparseEntry::NonZero(&values[local_index]))
|
|
||||||
} else if global_col_index < ncols {
|
|
||||||
Some(SparseEntry::Zero)
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn get_mut_entry_from_slices<'a, T>(ncols: usize,
|
|
||||||
col_indices: &'a [usize],
|
|
||||||
values: &'a mut [T],
|
|
||||||
global_col_index: usize) -> Option<SparseEntryMut<'a, T>> {
|
|
||||||
let local_index = col_indices.binary_search(&global_col_index);
|
|
||||||
if let Ok(local_index) = local_index {
|
|
||||||
Some(SparseEntryMut::NonZero(&mut values[local_index]))
|
|
||||||
} else if global_col_index < ncols {
|
|
||||||
Some(SparseEntryMut::Zero)
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl_csr_row_common_methods!(CsrRow<'a, T>);
|
impl_csr_row_common_methods!(CsrRow<'a, T>);
|
||||||
impl_csr_row_common_methods!(CsrRowMut<'a, T>);
|
impl_csr_row_common_methods!(CsrRowMut<'a, T>);
|
||||||
|
|
||||||
impl<'a, T> CsrRowMut<'a, T> {
|
impl<'a, T> CsrRowMut<'a, T> {
|
||||||
/// Mutable access to the values corresponding to explicitly stored entries in this row.
|
/// Mutable access to the values corresponding to explicitly stored entries in this row.
|
||||||
pub fn values_mut(&mut self) -> &mut [T] {
|
pub fn values_mut(&mut self) -> &mut [T] {
|
||||||
self.values
|
self.lane.values
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Provides simultaneous access to column indices and mutable values corresponding to the
|
/// Provides simultaneous access to column indices and mutable values corresponding to the
|
||||||
|
@ -558,37 +486,36 @@ impl<'a, T> CsrRowMut<'a, T> {
|
||||||
/// This method primarily facilitates low-level access for methods that process data stored
|
/// This method primarily facilitates low-level access for methods that process data stored
|
||||||
/// in CSR format directly.
|
/// in CSR format directly.
|
||||||
pub fn cols_and_values_mut(&mut self) -> (&[usize], &mut [T]) {
|
pub fn cols_and_values_mut(&mut self) -> (&[usize], &mut [T]) {
|
||||||
(self.col_indices, self.values)
|
(self.lane.minor_indices, self.lane.values)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns a mutable entry for the given global column index.
|
/// Returns a mutable entry for the given global column index.
|
||||||
pub fn get_entry_mut(&mut self, global_col_index: usize) -> Option<SparseEntryMut<T>> {
|
pub fn get_entry_mut(&mut self, global_col_index: usize) -> Option<SparseEntryMut<T>> {
|
||||||
get_mut_entry_from_slices(self.ncols, self.col_indices, self.values, global_col_index)
|
get_mut_entry_from_slices(self.lane.minor_dim,
|
||||||
|
self.lane.minor_indices,
|
||||||
|
self.lane.values,
|
||||||
|
global_col_index)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Row iterator for [CsrMatrix](struct.CsrMatrix.html).
|
/// Row iterator for [CsrMatrix](struct.CsrMatrix.html).
|
||||||
pub struct CsrRowIter<'a, T> {
|
pub struct CsrRowIter<'a, T> {
|
||||||
// The index of the row that will be returned on the next
|
lane_iter: CsLaneIter<'a, T>
|
||||||
current_row_idx: usize,
|
|
||||||
matrix: &'a CsrMatrix<T>
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, T> Iterator for CsrRowIter<'a, T> {
|
impl<'a, T> Iterator for CsrRowIter<'a, T> {
|
||||||
type Item = CsrRow<'a, T>;
|
type Item = CsrRow<'a, T>;
|
||||||
|
|
||||||
fn next(&mut self) -> Option<Self::Item> {
|
fn next(&mut self) -> Option<Self::Item> {
|
||||||
let row = self.matrix.get_row(self.current_row_idx);
|
self.lane_iter
|
||||||
self.current_row_idx += 1;
|
.next()
|
||||||
row
|
.map(|lane| CsrRow { lane })
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Mutable row iterator for [CsrMatrix](struct.CsrMatrix.html).
|
/// Mutable row iterator for [CsrMatrix](struct.CsrMatrix.html).
|
||||||
pub struct CsrRowIterMut<'a, T> {
|
pub struct CsrRowIterMut<'a, T> {
|
||||||
current_row_idx: usize,
|
lane_iter: CsLaneIterMut<'a, T>
|
||||||
pattern: &'a SparsityPattern,
|
|
||||||
remaining_values: *mut T,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, T> Iterator for CsrRowIterMut<'a, T>
|
impl<'a, T> Iterator for CsrRowIterMut<'a, T>
|
||||||
|
@ -598,27 +525,8 @@ where
|
||||||
type Item = CsrRowMut<'a, T>;
|
type Item = CsrRowMut<'a, T>;
|
||||||
|
|
||||||
fn next(&mut self) -> Option<Self::Item> {
|
fn next(&mut self) -> Option<Self::Item> {
|
||||||
let lane = self.pattern.get_lane(self.current_row_idx);
|
self.lane_iter
|
||||||
let ncols = self.pattern.minor_dim();
|
.next()
|
||||||
|
.map(|lane| CsrRowMut { lane })
|
||||||
if let Some(col_indices) = lane {
|
|
||||||
let count = col_indices.len();
|
|
||||||
|
|
||||||
// Note: I can't think of any way to construct this iterator without unsafe.
|
|
||||||
let values_in_row;
|
|
||||||
unsafe {
|
|
||||||
values_in_row = &mut *slice_from_raw_parts_mut(self.remaining_values, count);
|
|
||||||
self.remaining_values = self.remaining_values.add(count);
|
|
||||||
}
|
|
||||||
self.current_row_idx += 1;
|
|
||||||
|
|
||||||
Some(CsrRowMut {
|
|
||||||
ncols,
|
|
||||||
col_indices,
|
|
||||||
values: values_in_row
|
|
||||||
})
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
|
@ -90,6 +90,8 @@ pub mod pattern;
|
||||||
pub mod ops;
|
pub mod ops;
|
||||||
pub mod convert;
|
pub mod convert;
|
||||||
|
|
||||||
|
mod cs;
|
||||||
|
|
||||||
#[cfg(feature = "proptest-support")]
|
#[cfg(feature = "proptest-support")]
|
||||||
pub mod proptest;
|
pub mod proptest;
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue