diff --git a/nalgebra-sparse/src/csr.rs b/nalgebra-sparse/src/csr.rs index 33348bd5..20db241b 100644 --- a/nalgebra-sparse/src/csr.rs +++ b/nalgebra-sparse/src/csr.rs @@ -1,6 +1,5 @@ //! An implementation of the CSR sparse matrix format. - -use crate::{SparseFormatError, SparseFormatErrorKind}; +use crate::{SparseFormatError, SparseFormatErrorKind, SparseEntry, SparseEntryMut}; use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter}; use crate::csc::CscMatrix; @@ -298,40 +297,79 @@ impl CsrMatrix { pub fn pattern(&self) -> &Arc { &self.sparsity_pattern } -} - -impl CsrMatrix { - /// Return the value in the matrix at the given global row/col indices, or `None` if out of - /// bounds. - /// - /// If the indices are in bounds, but no explicitly stored entry is associated with it, - /// `T::zero()` is returned. Note that this method offers no way of distinguishing - /// explicitly stored zero entries from zero values that are only implicitly represented. - /// - /// Each call to this function incurs the cost of a binary search among the explicitly - /// stored column entries for the given row. - #[inline] - pub fn get(&self, row_index: usize, col_index: usize) -> Option { - self.get_row(row_index)?.get(col_index) - } - - /// Same as `get`, but panics if indices are out of bounds. - /// - /// Panics - /// ------ - /// Panics if either index is out of bounds. - #[inline] - pub fn index(&self, row_index: usize, col_index: usize) -> T { - self.get(row_index, col_index).unwrap() - } /// Reinterprets the CSR matrix as its transpose represented by a CSC matrix. + /// /// This operation does not touch the CSR data, and is effectively a no-op. pub fn transpose_as_csc(self) -> CscMatrix { let pattern = self.sparsity_pattern; let values = self.values; CscMatrix::try_from_pattern_and_values(pattern, values).unwrap() } + + /// Returns an entry for the given row/col indices, or `None` if the indices are out of bounds. + /// + /// Each call to this function incurs the cost of a binary search among the explicitly + /// stored column entries for the given row. + pub fn get_entry(&self, row_index: usize, col_index: usize) -> Option> { + let row_range = self.get_index_range(row_index)?; + let col_indices = &self.col_indices()[row_range.clone()]; + let values = &self.values()[row_range]; + get_entry_from_slices(self.ncols(), col_indices, values, col_index) + } + + /// Returns a mutable entry for the given row/col indices, or `None` if the indices are out + /// of bounds. + /// + /// Each call to this function incurs the cost of a binary search among the explicitly + /// stored column entries for the given row. + pub fn get_entry_mut(&mut self, row_index: usize, col_index: usize) + -> Option> { + let row_range = self.get_index_range(row_index)?; + let ncols = self.ncols(); + let (_, col_indices, values) = self.csr_data_mut(); + let col_indices = &col_indices[row_range.clone()]; + let values = &mut values[row_range]; + get_mut_entry_from_slices(ncols, col_indices, values, col_index) + } + + /// Returns an entry for the given row/col indices. + /// + /// Same as `get_entry`, except that it directly panics upon encountering row/col indices + /// out of bounds. + /// + /// Panics + /// ------ + /// Panics if `row_index` or `col_index` is out of bounds. + pub fn index_entry(&self, row_index: usize, col_index: usize) -> SparseEntry { + self.get_entry(row_index, col_index) + .expect("Out of bounds matrix indices encountered") + } + + /// Returns a mutable entry for the given row/col indices. + /// + /// Same as `get_entry_mut`, except that it directly panics upon encountering row/col indices + /// out of bounds. + /// + /// Panics + /// ------ + /// Panics if `row_index` or `col_index` is out of bounds. + pub fn index_entry_mut(&mut self, row_index: usize, col_index: usize) -> SparseEntryMut { + self.get_entry_mut(row_index, col_index) + .expect("Out of bounds matrix indices encountered") + } + + /// Returns a triplet of slices `(row_offsets, col_indices, values)` that make up the CSR data. + pub fn csr_data(&self) -> (&[usize], &[usize], &[T]) { + (self.row_offsets(), self.col_indices(), self.values()) + } + + /// Returns a triplet of slices `(row_offsets, col_indices, values)` that make up the CSR data, + /// where the `values` array is mutable. + pub fn csr_data_mut(&mut self) -> (&[usize], &[usize], &mut [T]) { + let pattern = self.sparsity_pattern.as_ref(); + (pattern.major_offsets(), pattern.minor_indices(), &mut self.values) + } } impl CsrMatrix @@ -465,32 +503,46 @@ macro_rules! impl_csr_row_common_methods { pub fn values(&self) -> &[T] { self.values } - } - impl<'a, T: Clone + Zero> $name { - /// Return the value in the matrix at the given global column index, or `None` if out of - /// bounds. - /// - /// If the index is in bounds, but no explicitly stored entry is associated with it, - /// `T::zero()` is returned. Note that this method offers no way of distinguishing - /// explicitly stored zero entries from zero values that are only implicitly represented. + /// Returns an entry for the given global column index. /// /// Each call to this function incurs the cost of a binary search among the explicitly - /// stored column entries for the current row. - pub fn get(&self, global_col_index: usize) -> Option { - let local_index = self.col_indices().binary_search(&global_col_index); - if let Ok(local_index) = local_index { - Some(self.values[local_index].clone()) - } else if global_col_index < self.ncols { - Some(T::zero()) - } else { - None - } + /// stored column entries. + pub fn get_entry(&self, global_col_index: usize) -> Option> { + get_entry_from_slices(self.ncols, self.col_indices, self.values, global_col_index) } } } } +fn get_entry_from_slices<'a, T>(ncols: usize, + col_indices: &'a [usize], + values: &'a [T], + global_col_index: usize) -> Option> { + let local_index = col_indices.binary_search(&global_col_index); + if let Ok(local_index) = local_index { + Some(SparseEntry::NonZero(&values[local_index])) + } else if global_col_index < ncols { + Some(SparseEntry::Zero) + } else { + None + } +} + +fn get_mut_entry_from_slices<'a, T>(ncols: usize, + col_indices: &'a [usize], + values: &'a mut [T], + global_col_index: usize) -> Option> { + let local_index = col_indices.binary_search(&global_col_index); + if let Ok(local_index) = local_index { + Some(SparseEntryMut::NonZero(&mut values[local_index])) + } else if global_col_index < ncols { + Some(SparseEntryMut::Zero) + } else { + None + } +} + impl_csr_row_common_methods!(CsrRow<'a, T>); impl_csr_row_common_methods!(CsrRowMut<'a, T>); @@ -508,6 +560,11 @@ impl<'a, T> CsrRowMut<'a, T> { pub fn cols_and_values_mut(&mut self) -> (&[usize], &mut [T]) { (self.col_indices, self.values) } + + /// Returns a mutable entry for the given global column index. + pub fn get_entry_mut(&mut self, global_col_index: usize) -> Option> { + get_mut_entry_from_slices(self.ncols, self.col_indices, self.values, global_col_index) + } } /// Row iterator for [CsrMatrix](struct.CsrMatrix.html). diff --git a/nalgebra-sparse/src/lib.rs b/nalgebra-sparse/src/lib.rs index 688362f9..536f4b75 100644 --- a/nalgebra-sparse/src/lib.rs +++ b/nalgebra-sparse/src/lib.rs @@ -95,6 +95,7 @@ pub mod proptest; use std::error::Error; use std::fmt; +use num_traits::Zero; /// Errors produced by functions that expect well-formed sparse format data. #[derive(Debug)] @@ -148,4 +149,42 @@ impl fmt::Display for SparseFormatError { } } -impl Error for SparseFormatError {} \ No newline at end of file +impl Error for SparseFormatError {} + +/// TODO +#[derive(Debug, PartialEq, Eq)] +pub enum SparseEntry<'a, T> { + /// TODO + NonZero(&'a T), + /// TODO + Zero +} + +impl<'a, T: Clone + Zero> SparseEntry<'a, T> { + /// TODO + pub fn to_value(self) -> T { + match self { + SparseEntry::NonZero(value) => value.clone(), + SparseEntry::Zero => T::zero() + } + } +} + +/// TODO +#[derive(Debug, PartialEq, Eq)] +pub enum SparseEntryMut<'a, T> { + /// TODO + NonZero(&'a mut T), + /// TODO + Zero +} + +impl<'a, T: Clone + Zero> SparseEntryMut<'a, T> { + /// TODO + pub fn to_value(self) -> T { + match self { + SparseEntryMut::NonZero(value) => value.clone(), + SparseEntryMut::Zero => T::zero() + } + } +} \ No newline at end of file