Rename CsrMatrix::get(_mut) to get_entry(_mut) and change semantics

This commit is contained in:
Andreas Longva 2020-12-09 15:25:16 +01:00
parent 830df6d07b
commit 921686c490
2 changed files with 143 additions and 47 deletions

View File

@ -1,6 +1,5 @@
//! An implementation of the CSR sparse matrix format.
use crate::{SparseFormatError, SparseFormatErrorKind};
use crate::{SparseFormatError, SparseFormatErrorKind, SparseEntry, SparseEntryMut};
use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter};
use crate::csc::CscMatrix;
@ -298,40 +297,79 @@ impl<T> CsrMatrix<T> {
pub fn pattern(&self) -> &Arc<SparsityPattern> {
&self.sparsity_pattern
}
}
impl<T: Clone + Zero> CsrMatrix<T> {
/// Return the value in the matrix at the given global row/col indices, or `None` if out of
/// bounds.
///
/// If the indices are in bounds, but no explicitly stored entry is associated with it,
/// `T::zero()` is returned. Note that this method offers no way of distinguishing
/// explicitly stored zero entries from zero values that are only implicitly represented.
///
/// Each call to this function incurs the cost of a binary search among the explicitly
/// stored column entries for the given row.
#[inline]
pub fn get(&self, row_index: usize, col_index: usize) -> Option<T> {
self.get_row(row_index)?.get(col_index)
}
/// Same as `get`, but panics if indices are out of bounds.
///
/// Panics
/// ------
/// Panics if either index is out of bounds.
#[inline]
pub fn index(&self, row_index: usize, col_index: usize) -> T {
self.get(row_index, col_index).unwrap()
}
/// Reinterprets the CSR matrix as its transpose represented by a CSC matrix.
///
/// This operation does not touch the CSR data, and is effectively a no-op.
pub fn transpose_as_csc(self) -> CscMatrix<T> {
let pattern = self.sparsity_pattern;
let values = self.values;
CscMatrix::try_from_pattern_and_values(pattern, values).unwrap()
}
/// Returns an entry for the given row/col indices, or `None` if the indices are out of bounds.
///
/// Each call to this function incurs the cost of a binary search among the explicitly
/// stored column entries for the given row.
pub fn get_entry(&self, row_index: usize, col_index: usize) -> Option<SparseEntry<T>> {
let row_range = self.get_index_range(row_index)?;
let col_indices = &self.col_indices()[row_range.clone()];
let values = &self.values()[row_range];
get_entry_from_slices(self.ncols(), col_indices, values, col_index)
}
/// Returns a mutable entry for the given row/col indices, or `None` if the indices are out
/// of bounds.
///
/// Each call to this function incurs the cost of a binary search among the explicitly
/// stored column entries for the given row.
pub fn get_entry_mut(&mut self, row_index: usize, col_index: usize)
-> Option<SparseEntryMut<T>> {
let row_range = self.get_index_range(row_index)?;
let ncols = self.ncols();
let (_, col_indices, values) = self.csr_data_mut();
let col_indices = &col_indices[row_range.clone()];
let values = &mut values[row_range];
get_mut_entry_from_slices(ncols, col_indices, values, col_index)
}
/// Returns an entry for the given row/col indices.
///
/// Same as `get_entry`, except that it directly panics upon encountering row/col indices
/// out of bounds.
///
/// Panics
/// ------
/// Panics if `row_index` or `col_index` is out of bounds.
pub fn index_entry(&self, row_index: usize, col_index: usize) -> SparseEntry<T> {
self.get_entry(row_index, col_index)
.expect("Out of bounds matrix indices encountered")
}
/// Returns a mutable entry for the given row/col indices.
///
/// Same as `get_entry_mut`, except that it directly panics upon encountering row/col indices
/// out of bounds.
///
/// Panics
/// ------
/// Panics if `row_index` or `col_index` is out of bounds.
pub fn index_entry_mut(&mut self, row_index: usize, col_index: usize) -> SparseEntryMut<T> {
self.get_entry_mut(row_index, col_index)
.expect("Out of bounds matrix indices encountered")
}
/// Returns a triplet of slices `(row_offsets, col_indices, values)` that make up the CSR data.
pub fn csr_data(&self) -> (&[usize], &[usize], &[T]) {
(self.row_offsets(), self.col_indices(), self.values())
}
/// Returns a triplet of slices `(row_offsets, col_indices, values)` that make up the CSR data,
/// where the `values` array is mutable.
pub fn csr_data_mut(&mut self) -> (&[usize], &[usize], &mut [T]) {
let pattern = self.sparsity_pattern.as_ref();
(pattern.major_offsets(), pattern.minor_indices(), &mut self.values)
}
}
impl<T> CsrMatrix<T>
@ -465,29 +503,43 @@ macro_rules! impl_csr_row_common_methods {
pub fn values(&self) -> &[T] {
self.values
}
}
impl<'a, T: Clone + Zero> $name {
/// Return the value in the matrix at the given global column index, or `None` if out of
/// bounds.
///
/// If the index is in bounds, but no explicitly stored entry is associated with it,
/// `T::zero()` is returned. Note that this method offers no way of distinguishing
/// explicitly stored zero entries from zero values that are only implicitly represented.
/// Returns an entry for the given global column index.
///
/// Each call to this function incurs the cost of a binary search among the explicitly
/// stored column entries for the current row.
pub fn get(&self, global_col_index: usize) -> Option<T> {
let local_index = self.col_indices().binary_search(&global_col_index);
/// stored column entries.
pub fn get_entry(&self, global_col_index: usize) -> Option<SparseEntry<T>> {
get_entry_from_slices(self.ncols, self.col_indices, self.values, global_col_index)
}
}
}
}
fn get_entry_from_slices<'a, T>(ncols: usize,
col_indices: &'a [usize],
values: &'a [T],
global_col_index: usize) -> Option<SparseEntry<'a, T>> {
let local_index = col_indices.binary_search(&global_col_index);
if let Ok(local_index) = local_index {
Some(self.values[local_index].clone())
} else if global_col_index < self.ncols {
Some(T::zero())
Some(SparseEntry::NonZero(&values[local_index]))
} else if global_col_index < ncols {
Some(SparseEntry::Zero)
} else {
None
}
}
}
}
fn get_mut_entry_from_slices<'a, T>(ncols: usize,
col_indices: &'a [usize],
values: &'a mut [T],
global_col_index: usize) -> Option<SparseEntryMut<'a, T>> {
let local_index = col_indices.binary_search(&global_col_index);
if let Ok(local_index) = local_index {
Some(SparseEntryMut::NonZero(&mut values[local_index]))
} else if global_col_index < ncols {
Some(SparseEntryMut::Zero)
} else {
None
}
}
@ -508,6 +560,11 @@ impl<'a, T> CsrRowMut<'a, T> {
pub fn cols_and_values_mut(&mut self) -> (&[usize], &mut [T]) {
(self.col_indices, self.values)
}
/// Returns a mutable entry for the given global column index.
pub fn get_entry_mut(&mut self, global_col_index: usize) -> Option<SparseEntryMut<T>> {
get_mut_entry_from_slices(self.ncols, self.col_indices, self.values, global_col_index)
}
}
/// Row iterator for [CsrMatrix](struct.CsrMatrix.html).

View File

@ -95,6 +95,7 @@ pub mod proptest;
use std::error::Error;
use std::fmt;
use num_traits::Zero;
/// Errors produced by functions that expect well-formed sparse format data.
#[derive(Debug)]
@ -149,3 +150,41 @@ impl fmt::Display for SparseFormatError {
}
impl Error for SparseFormatError {}
/// TODO
#[derive(Debug, PartialEq, Eq)]
pub enum SparseEntry<'a, T> {
/// TODO
NonZero(&'a T),
/// TODO
Zero
}
impl<'a, T: Clone + Zero> SparseEntry<'a, T> {
/// TODO
pub fn to_value(self) -> T {
match self {
SparseEntry::NonZero(value) => value.clone(),
SparseEntry::Zero => T::zero()
}
}
}
/// TODO
#[derive(Debug, PartialEq, Eq)]
pub enum SparseEntryMut<'a, T> {
/// TODO
NonZero(&'a mut T),
/// TODO
Zero
}
impl<'a, T: Clone + Zero> SparseEntryMut<'a, T> {
/// TODO
pub fn to_value(self) -> T {
match self {
SparseEntryMut::NonZero(value) => value.clone(),
SparseEntryMut::Zero => T::zero()
}
}
}