CSR row access and iterators
This commit is contained in:
parent
41425ae52c
commit
7f5b702a49
|
@ -3,12 +3,17 @@ use crate::iter::SparsityPatternIter;
|
||||||
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::slice::{IterMut, Iter};
|
use std::slice::{IterMut, Iter};
|
||||||
|
use std::ops::Range;
|
||||||
|
use num_traits::Zero;
|
||||||
|
use std::ptr::slice_from_raw_parts_mut;
|
||||||
|
|
||||||
/// A CSR representation of a sparse matrix.
|
/// A CSR representation of a sparse matrix.
|
||||||
///
|
///
|
||||||
/// The Compressed Row Storage (CSR) format is well-suited as a general-purpose storage format
|
/// The Compressed Row Storage (CSR) format is well-suited as a general-purpose storage format
|
||||||
/// for many sparse matrix applications.
|
/// for many sparse matrix applications.
|
||||||
///
|
///
|
||||||
|
/// TODO: Storage explanation and examples
|
||||||
|
///
|
||||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
pub struct CsrMatrix<T> {
|
pub struct CsrMatrix<T> {
|
||||||
// Rows are major, cols are minor in the sparsity pattern
|
// Rows are major, cols are minor in the sparsity pattern
|
||||||
|
@ -151,8 +156,103 @@ impl<T> CsrMatrix<T> {
|
||||||
values_mut_iter: self.values.iter_mut()
|
values_mut_iter: self.values.iter_mut()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Return the row at the given row index.
|
||||||
|
///
|
||||||
|
/// Panics
|
||||||
|
/// ------
|
||||||
|
/// Panics if row index is out of bounds.
|
||||||
|
#[inline]
|
||||||
|
pub fn row(&self, index: usize) -> CsrRow<T> {
|
||||||
|
self.get_row(index)
|
||||||
|
.expect("Row index must be in bounds")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Mutable row access for the given row index.
|
||||||
|
///
|
||||||
|
/// Panics
|
||||||
|
/// ------
|
||||||
|
/// Panics if row index is out of bounds.
|
||||||
|
#[inline]
|
||||||
|
pub fn row_mut(&mut self, index: usize) -> CsrRowMut<T> {
|
||||||
|
self.get_row_mut(index)
|
||||||
|
.expect("Row index must be in bounds")
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Return the row at the given row index, or `None` if out of bounds.
|
||||||
|
#[inline]
|
||||||
|
pub fn get_row(&self, index: usize) -> Option<CsrRow<T>> {
|
||||||
|
let range = self.get_index_range(index)?;
|
||||||
|
Some(CsrRow {
|
||||||
|
col_indices: &self.sparsity_pattern.minor_indices()[range.clone()],
|
||||||
|
values: &self.values[range],
|
||||||
|
ncols: self.ncols()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Mutable row access for the given row index, or `None` if out of bounds.
|
||||||
|
#[inline]
|
||||||
|
pub fn get_row_mut(&mut self, index: usize) -> Option<CsrRowMut<T>> {
|
||||||
|
let range = self.get_index_range(index)?;
|
||||||
|
Some(CsrRowMut {
|
||||||
|
ncols: self.ncols(),
|
||||||
|
col_indices: &self.sparsity_pattern.minor_indices()[range.clone()],
|
||||||
|
values: &mut self.values[range]
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Internal method for simplifying access to a row's data.
|
||||||
|
fn get_index_range(&self, row_index: usize) -> Option<Range<usize>> {
|
||||||
|
let row_begin = *self.sparsity_pattern.major_offsets().get(row_index)?;
|
||||||
|
let row_end = *self.sparsity_pattern.major_offsets().get(row_index + 1)?;
|
||||||
|
Some(row_begin .. row_end)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// An iterator over rows in the matrix.
|
||||||
|
pub fn row_iter(&self) -> CsrRowIter<T> {
|
||||||
|
CsrRowIter {
|
||||||
|
current_row_idx: 0,
|
||||||
|
matrix: self
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A mutable iterator over rows in the matrix.
|
||||||
|
pub fn row_iter_mut(&mut self) -> CsrRowIterMut<T> {
|
||||||
|
CsrRowIterMut {
|
||||||
|
current_row_idx: 0,
|
||||||
|
pattern: &self.sparsity_pattern,
|
||||||
|
remaining_values: self.values.as_mut_ptr()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: Clone + Zero> CsrMatrix<T> {
|
||||||
|
/// Return the value in the matrix at the given global row/col indices, or `None` if out of
|
||||||
|
/// bounds.
|
||||||
|
///
|
||||||
|
/// If the indices are in bounds, but no explicitly stored entry is associated with it,
|
||||||
|
/// `T::zero()` is returned. Note that this methods offers no way of distinguishing
|
||||||
|
/// explicitly stored zero entries from zero values that are only implicitly represented.
|
||||||
|
///
|
||||||
|
/// Each call to this function incurs the cost of a binary search among the explicitly
|
||||||
|
/// stored column entries for the given row.
|
||||||
|
#[inline]
|
||||||
|
pub fn get(&self, row_index: usize, col_index: usize) -> Option<T> {
|
||||||
|
self.get_row(row_index)?.get(col_index)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Same as `get`, but panics if indices are out of bounds.
|
||||||
|
///
|
||||||
|
/// Panics
|
||||||
|
/// ------
|
||||||
|
/// Panics if either index is out of bounds.
|
||||||
|
#[inline]
|
||||||
|
pub fn index(&self, row_index: usize, col_index: usize) -> T {
|
||||||
|
self.get(row_index, col_index).unwrap()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Iterator type for iterating over triplets in a CSR matrix.
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct CsrTripletIter<'a, T> {
|
pub struct CsrTripletIter<'a, T> {
|
||||||
pattern_iter: SparsityPatternIter<'a>,
|
pattern_iter: SparsityPatternIter<'a>,
|
||||||
|
@ -173,6 +273,7 @@ impl<'a, T> Iterator for CsrTripletIter<'a, T> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Iterator type for mutably iterating over triplets in a CSR matrix.
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct CsrTripletIterMut<'a, T> {
|
pub struct CsrTripletIterMut<'a, T> {
|
||||||
pattern_iter: SparsityPatternIter<'a>,
|
pattern_iter: SparsityPatternIter<'a>,
|
||||||
|
@ -193,3 +294,148 @@ impl<'a, T> Iterator for CsrTripletIterMut<'a, T> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// An immutable representation of a row in a CSR matrix.
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub struct CsrRow<'a, T> {
|
||||||
|
ncols: usize,
|
||||||
|
col_indices: &'a [usize],
|
||||||
|
values: &'a [T],
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A mutable representation of a row in a CSR matrix.
|
||||||
|
///
|
||||||
|
/// Note that only explicitly stored entries can be mutated. The sparsity pattern belonging
|
||||||
|
/// to the row cannot be modified.
|
||||||
|
#[derive(Debug, PartialEq, Eq)]
|
||||||
|
pub struct CsrRowMut<'a, T> {
|
||||||
|
ncols: usize,
|
||||||
|
col_indices: &'a [usize],
|
||||||
|
values: &'a mut [T]
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Implement the methods common to both CsrRow and CsrRowMut
|
||||||
|
macro_rules! impl_csr_row_common_methods {
|
||||||
|
($name:ty) => {
|
||||||
|
impl<'a, T> $name {
|
||||||
|
/// The number of global columns in the row.
|
||||||
|
#[inline]
|
||||||
|
pub fn ncols(&self) -> usize {
|
||||||
|
self.ncols
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The number of non-zeros in this row.
|
||||||
|
#[inline]
|
||||||
|
pub fn nnz(&self) -> usize {
|
||||||
|
self.col_indices.len()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The column indices corresponding to explicitly stored entries in this row.
|
||||||
|
#[inline]
|
||||||
|
pub fn col_indices(&self) -> &[usize] {
|
||||||
|
self.col_indices
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The values corresponding to explicitly stored entries in this row.
|
||||||
|
#[inline]
|
||||||
|
pub fn values(&self) -> &[T] {
|
||||||
|
self.values
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, T: Clone + Zero> $name {
|
||||||
|
/// Return the value in the matrix at the given global column index, or `None` if out of
|
||||||
|
/// bounds.
|
||||||
|
///
|
||||||
|
/// If the index is in bounds, but no explicitly stored entry is associated with it,
|
||||||
|
/// `T::zero()` is returned. Note that this methods offers no way of distinguishing
|
||||||
|
/// explicitly stored zero entries from zero values that are only implicitly represented.
|
||||||
|
///
|
||||||
|
/// Each call to this function incurs the cost of a binary search among the explicitly
|
||||||
|
/// stored column entries for the current row.
|
||||||
|
pub fn get(&self, global_col_index: usize) -> Option<T> {
|
||||||
|
let local_index = self.col_indices().binary_search(&global_col_index);
|
||||||
|
if let Ok(local_index) = local_index {
|
||||||
|
Some(self.values[local_index].clone())
|
||||||
|
} else if global_col_index < self.ncols {
|
||||||
|
Some(T::zero())
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl_csr_row_common_methods!(CsrRow<'a, T>);
|
||||||
|
impl_csr_row_common_methods!(CsrRowMut<'a, T>);
|
||||||
|
|
||||||
|
impl<'a, T> CsrRowMut<'a, T> {
|
||||||
|
/// Mutable access to the values corresponding to explicitly stored entries in this row.
|
||||||
|
pub fn values_mut(&mut self) -> &mut [T] {
|
||||||
|
self.values
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Provides simultaneous access to column indices and mutable values corresponding to the
|
||||||
|
/// explicitly stored entries in this row.
|
||||||
|
///
|
||||||
|
/// This method primarily facilitates low-level access for methods that process data stored
|
||||||
|
/// in CSR format directly.
|
||||||
|
pub fn cols_and_values_mut(&mut self) -> (&[usize], &mut [T]) {
|
||||||
|
(self.col_indices, self.values)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct CsrRowIter<'a, T> {
|
||||||
|
// The index of the row that will be returned on the next
|
||||||
|
current_row_idx: usize,
|
||||||
|
matrix: &'a CsrMatrix<T>
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, T> Iterator for CsrRowIter<'a, T> {
|
||||||
|
type Item = CsrRow<'a, T>;
|
||||||
|
|
||||||
|
fn next(&mut self) -> Option<Self::Item> {
|
||||||
|
let row = self.matrix.get_row(self.current_row_idx);
|
||||||
|
self.current_row_idx += 1;
|
||||||
|
row
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct CsrRowIterMut<'a, T> {
|
||||||
|
current_row_idx: usize,
|
||||||
|
pattern: &'a SparsityPattern,
|
||||||
|
remaining_values: *mut T,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, T> Iterator for CsrRowIterMut<'a, T>
|
||||||
|
where
|
||||||
|
T: 'a
|
||||||
|
{
|
||||||
|
type Item = CsrRowMut<'a, T>;
|
||||||
|
|
||||||
|
fn next(&mut self) -> Option<Self::Item> {
|
||||||
|
let lane = self.pattern.lane(self.current_row_idx);
|
||||||
|
let ncols = self.pattern.minor_dim();
|
||||||
|
|
||||||
|
if let Some(col_indices) = lane {
|
||||||
|
let count = col_indices.len();
|
||||||
|
|
||||||
|
// Note: I can't think of any way to construct this iterator without unsafe.
|
||||||
|
let values_in_row;
|
||||||
|
unsafe {
|
||||||
|
values_in_row = &mut *slice_from_raw_parts_mut(self.remaining_values, count);
|
||||||
|
self.remaining_values = self.remaining_values.add(count);
|
||||||
|
}
|
||||||
|
self.current_row_idx += 1;
|
||||||
|
|
||||||
|
Some(CsrRowMut {
|
||||||
|
ncols,
|
||||||
|
col_indices,
|
||||||
|
values: values_in_row
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,3 +1,13 @@
|
||||||
|
//! Sparse matrices and algorithms for nalgebra.
|
||||||
|
//!
|
||||||
|
//! TODO: Docs
|
||||||
|
#![deny(non_camel_case_types)]
|
||||||
|
#![deny(unused_parens)]
|
||||||
|
#![deny(non_upper_case_globals)]
|
||||||
|
#![deny(unused_qualifications)]
|
||||||
|
#![deny(unused_results)]
|
||||||
|
#![deny(missing_docs)]
|
||||||
|
|
||||||
mod coo;
|
mod coo;
|
||||||
mod csr;
|
mod csr;
|
||||||
mod pattern;
|
mod pattern;
|
||||||
|
@ -5,7 +15,7 @@ mod pattern;
|
||||||
pub mod ops;
|
pub mod ops;
|
||||||
|
|
||||||
pub use coo::CooMatrix;
|
pub use coo::CooMatrix;
|
||||||
pub use csr::CsrMatrix;
|
pub use csr::{CsrMatrix, CsrRow, CsrRowMut};
|
||||||
pub use pattern::{SparsityPattern};
|
pub use pattern::{SparsityPattern};
|
||||||
|
|
||||||
/// Iterator types for matrices.
|
/// Iterator types for matrices.
|
||||||
|
@ -25,6 +35,7 @@ pub mod iter {
|
||||||
use std::error::Error;
|
use std::error::Error;
|
||||||
use std::fmt;
|
use std::fmt;
|
||||||
|
|
||||||
|
/// Errors produced by functions that expect well-formed sparse format data.
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
#[non_exhaustive]
|
#[non_exhaustive]
|
||||||
pub enum SparseFormatError {
|
pub enum SparseFormatError {
|
||||||
|
|
|
@ -105,6 +105,7 @@ impl SparsityPattern {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Iterator type for iterating over entries in a sparsity pattern.
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct SparsityPatternIter<'a> {
|
pub struct SparsityPatternIter<'a> {
|
||||||
// See implementation of Iterator::next for an explanation of how these members are used
|
// See implementation of Iterator::next for an explanation of how these members are used
|
||||||
|
|
Loading…
Reference in New Issue