From 7f5b702a49bfa08603f8a592d03f778df3c8c6c5 Mon Sep 17 00:00:00 2001 From: Andreas Longva Date: Tue, 21 Jul 2020 17:39:06 +0200 Subject: [PATCH] CSR row access and iterators --- nalgebra-sparse/src/csr.rs | 246 +++++++++++++++++++++++++++++++++ nalgebra-sparse/src/lib.rs | 13 +- nalgebra-sparse/src/pattern.rs | 1 + 3 files changed, 259 insertions(+), 1 deletion(-) diff --git a/nalgebra-sparse/src/csr.rs b/nalgebra-sparse/src/csr.rs index 6fab9212..5c526fb1 100644 --- a/nalgebra-sparse/src/csr.rs +++ b/nalgebra-sparse/src/csr.rs @@ -3,12 +3,17 @@ use crate::iter::SparsityPatternIter; use std::sync::Arc; use std::slice::{IterMut, Iter}; +use std::ops::Range; +use num_traits::Zero; +use std::ptr::slice_from_raw_parts_mut; /// A CSR representation of a sparse matrix. /// /// The Compressed Row Storage (CSR) format is well-suited as a general-purpose storage format /// for many sparse matrix applications. /// +/// TODO: Storage explanation and examples +/// #[derive(Debug, Clone, PartialEq, Eq)] pub struct CsrMatrix { // Rows are major, cols are minor in the sparsity pattern @@ -151,8 +156,103 @@ impl CsrMatrix { values_mut_iter: self.values.iter_mut() } } + + /// Return the row at the given row index. + /// + /// Panics + /// ------ + /// Panics if row index is out of bounds. + #[inline] + pub fn row(&self, index: usize) -> CsrRow { + self.get_row(index) + .expect("Row index must be in bounds") + } + + /// Mutable row access for the given row index. + /// + /// Panics + /// ------ + /// Panics if row index is out of bounds. + #[inline] + pub fn row_mut(&mut self, index: usize) -> CsrRowMut { + self.get_row_mut(index) + .expect("Row index must be in bounds") + } + + /// Return the row at the given row index, or `None` if out of bounds. + #[inline] + pub fn get_row(&self, index: usize) -> Option> { + let range = self.get_index_range(index)?; + Some(CsrRow { + col_indices: &self.sparsity_pattern.minor_indices()[range.clone()], + values: &self.values[range], + ncols: self.ncols() + }) + } + + /// Mutable row access for the given row index, or `None` if out of bounds. + #[inline] + pub fn get_row_mut(&mut self, index: usize) -> Option> { + let range = self.get_index_range(index)?; + Some(CsrRowMut { + ncols: self.ncols(), + col_indices: &self.sparsity_pattern.minor_indices()[range.clone()], + values: &mut self.values[range] + }) + } + + /// Internal method for simplifying access to a row's data. + fn get_index_range(&self, row_index: usize) -> Option> { + let row_begin = *self.sparsity_pattern.major_offsets().get(row_index)?; + let row_end = *self.sparsity_pattern.major_offsets().get(row_index + 1)?; + Some(row_begin .. row_end) + } + + /// An iterator over rows in the matrix. + pub fn row_iter(&self) -> CsrRowIter { + CsrRowIter { + current_row_idx: 0, + matrix: self + } + } + + /// A mutable iterator over rows in the matrix. + pub fn row_iter_mut(&mut self) -> CsrRowIterMut { + CsrRowIterMut { + current_row_idx: 0, + pattern: &self.sparsity_pattern, + remaining_values: self.values.as_mut_ptr() + } + } } +impl CsrMatrix { + /// Return the value in the matrix at the given global row/col indices, or `None` if out of + /// bounds. + /// + /// If the indices are in bounds, but no explicitly stored entry is associated with it, + /// `T::zero()` is returned. Note that this methods offers no way of distinguishing + /// explicitly stored zero entries from zero values that are only implicitly represented. + /// + /// Each call to this function incurs the cost of a binary search among the explicitly + /// stored column entries for the given row. + #[inline] + pub fn get(&self, row_index: usize, col_index: usize) -> Option { + self.get_row(row_index)?.get(col_index) + } + + /// Same as `get`, but panics if indices are out of bounds. + /// + /// Panics + /// ------ + /// Panics if either index is out of bounds. + #[inline] + pub fn index(&self, row_index: usize, col_index: usize) -> T { + self.get(row_index, col_index).unwrap() + } +} + +/// Iterator type for iterating over triplets in a CSR matrix. #[derive(Debug)] pub struct CsrTripletIter<'a, T> { pattern_iter: SparsityPatternIter<'a>, @@ -173,6 +273,7 @@ impl<'a, T> Iterator for CsrTripletIter<'a, T> { } } +/// Iterator type for mutably iterating over triplets in a CSR matrix. #[derive(Debug)] pub struct CsrTripletIterMut<'a, T> { pattern_iter: SparsityPatternIter<'a>, @@ -192,4 +293,149 @@ impl<'a, T> Iterator for CsrTripletIterMut<'a, T> { _ => None } } +} + +/// An immutable representation of a row in a CSR matrix. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct CsrRow<'a, T> { + ncols: usize, + col_indices: &'a [usize], + values: &'a [T], +} + +/// A mutable representation of a row in a CSR matrix. +/// +/// Note that only explicitly stored entries can be mutated. The sparsity pattern belonging +/// to the row cannot be modified. +#[derive(Debug, PartialEq, Eq)] +pub struct CsrRowMut<'a, T> { + ncols: usize, + col_indices: &'a [usize], + values: &'a mut [T] +} + +/// Implement the methods common to both CsrRow and CsrRowMut +macro_rules! impl_csr_row_common_methods { + ($name:ty) => { + impl<'a, T> $name { + /// The number of global columns in the row. + #[inline] + pub fn ncols(&self) -> usize { + self.ncols + } + + /// The number of non-zeros in this row. + #[inline] + pub fn nnz(&self) -> usize { + self.col_indices.len() + } + + /// The column indices corresponding to explicitly stored entries in this row. + #[inline] + pub fn col_indices(&self) -> &[usize] { + self.col_indices + } + + /// The values corresponding to explicitly stored entries in this row. + #[inline] + pub fn values(&self) -> &[T] { + self.values + } + } + + impl<'a, T: Clone + Zero> $name { + /// Return the value in the matrix at the given global column index, or `None` if out of + /// bounds. + /// + /// If the index is in bounds, but no explicitly stored entry is associated with it, + /// `T::zero()` is returned. Note that this methods offers no way of distinguishing + /// explicitly stored zero entries from zero values that are only implicitly represented. + /// + /// Each call to this function incurs the cost of a binary search among the explicitly + /// stored column entries for the current row. + pub fn get(&self, global_col_index: usize) -> Option { + let local_index = self.col_indices().binary_search(&global_col_index); + if let Ok(local_index) = local_index { + Some(self.values[local_index].clone()) + } else if global_col_index < self.ncols { + Some(T::zero()) + } else { + None + } + } + } + } +} + +impl_csr_row_common_methods!(CsrRow<'a, T>); +impl_csr_row_common_methods!(CsrRowMut<'a, T>); + +impl<'a, T> CsrRowMut<'a, T> { + /// Mutable access to the values corresponding to explicitly stored entries in this row. + pub fn values_mut(&mut self) -> &mut [T] { + self.values + } + + /// Provides simultaneous access to column indices and mutable values corresponding to the + /// explicitly stored entries in this row. + /// + /// This method primarily facilitates low-level access for methods that process data stored + /// in CSR format directly. + pub fn cols_and_values_mut(&mut self) -> (&[usize], &mut [T]) { + (self.col_indices, self.values) + } +} + +pub struct CsrRowIter<'a, T> { + // The index of the row that will be returned on the next + current_row_idx: usize, + matrix: &'a CsrMatrix +} + +impl<'a, T> Iterator for CsrRowIter<'a, T> { + type Item = CsrRow<'a, T>; + + fn next(&mut self) -> Option { + let row = self.matrix.get_row(self.current_row_idx); + self.current_row_idx += 1; + row + } +} + +pub struct CsrRowIterMut<'a, T> { + current_row_idx: usize, + pattern: &'a SparsityPattern, + remaining_values: *mut T, +} + +impl<'a, T> Iterator for CsrRowIterMut<'a, T> +where + T: 'a +{ + type Item = CsrRowMut<'a, T>; + + fn next(&mut self) -> Option { + let lane = self.pattern.lane(self.current_row_idx); + let ncols = self.pattern.minor_dim(); + + if let Some(col_indices) = lane { + let count = col_indices.len(); + + // Note: I can't think of any way to construct this iterator without unsafe. + let values_in_row; + unsafe { + values_in_row = &mut *slice_from_raw_parts_mut(self.remaining_values, count); + self.remaining_values = self.remaining_values.add(count); + } + self.current_row_idx += 1; + + Some(CsrRowMut { + ncols, + col_indices, + values: values_in_row + }) + } else { + None + } + } } \ No newline at end of file diff --git a/nalgebra-sparse/src/lib.rs b/nalgebra-sparse/src/lib.rs index 9967e94a..ce77de62 100644 --- a/nalgebra-sparse/src/lib.rs +++ b/nalgebra-sparse/src/lib.rs @@ -1,3 +1,13 @@ +//! Sparse matrices and algorithms for nalgebra. +//! +//! TODO: Docs +#![deny(non_camel_case_types)] +#![deny(unused_parens)] +#![deny(non_upper_case_globals)] +#![deny(unused_qualifications)] +#![deny(unused_results)] +#![deny(missing_docs)] + mod coo; mod csr; mod pattern; @@ -5,7 +15,7 @@ mod pattern; pub mod ops; pub use coo::CooMatrix; -pub use csr::CsrMatrix; +pub use csr::{CsrMatrix, CsrRow, CsrRowMut}; pub use pattern::{SparsityPattern}; /// Iterator types for matrices. @@ -25,6 +35,7 @@ pub mod iter { use std::error::Error; use std::fmt; +/// Errors produced by functions that expect well-formed sparse format data. #[derive(Debug)] #[non_exhaustive] pub enum SparseFormatError { diff --git a/nalgebra-sparse/src/pattern.rs b/nalgebra-sparse/src/pattern.rs index 91e21902..c64d7054 100644 --- a/nalgebra-sparse/src/pattern.rs +++ b/nalgebra-sparse/src/pattern.rs @@ -105,6 +105,7 @@ impl SparsityPattern { } } +/// Iterator type for iterating over entries in a sparsity pattern. #[derive(Debug, Clone)] pub struct SparsityPatternIter<'a> { // See implementation of Iterator::next for an explanation of how these members are used