Rename CsrMatrix::get(_mut) to get_entry(_mut) and change semantics
This commit is contained in:
parent
830df6d07b
commit
921686c490
@ -1,6 +1,5 @@
|
||||
//! An implementation of the CSR sparse matrix format.
|
||||
|
||||
use crate::{SparseFormatError, SparseFormatErrorKind};
|
||||
use crate::{SparseFormatError, SparseFormatErrorKind, SparseEntry, SparseEntryMut};
|
||||
use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter};
|
||||
use crate::csc::CscMatrix;
|
||||
|
||||
@ -298,40 +297,79 @@ impl<T> CsrMatrix<T> {
|
||||
pub fn pattern(&self) -> &Arc<SparsityPattern> {
|
||||
&self.sparsity_pattern
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Clone + Zero> CsrMatrix<T> {
|
||||
/// Return the value in the matrix at the given global row/col indices, or `None` if out of
|
||||
/// bounds.
|
||||
///
|
||||
/// If the indices are in bounds, but no explicitly stored entry is associated with it,
|
||||
/// `T::zero()` is returned. Note that this method offers no way of distinguishing
|
||||
/// explicitly stored zero entries from zero values that are only implicitly represented.
|
||||
///
|
||||
/// Each call to this function incurs the cost of a binary search among the explicitly
|
||||
/// stored column entries for the given row.
|
||||
#[inline]
|
||||
pub fn get(&self, row_index: usize, col_index: usize) -> Option<T> {
|
||||
self.get_row(row_index)?.get(col_index)
|
||||
}
|
||||
|
||||
/// Same as `get`, but panics if indices are out of bounds.
|
||||
///
|
||||
/// Panics
|
||||
/// ------
|
||||
/// Panics if either index is out of bounds.
|
||||
#[inline]
|
||||
pub fn index(&self, row_index: usize, col_index: usize) -> T {
|
||||
self.get(row_index, col_index).unwrap()
|
||||
}
|
||||
|
||||
/// Reinterprets the CSR matrix as its transpose represented by a CSC matrix.
|
||||
///
|
||||
/// This operation does not touch the CSR data, and is effectively a no-op.
|
||||
pub fn transpose_as_csc(self) -> CscMatrix<T> {
|
||||
let pattern = self.sparsity_pattern;
|
||||
let values = self.values;
|
||||
CscMatrix::try_from_pattern_and_values(pattern, values).unwrap()
|
||||
}
|
||||
|
||||
/// Returns an entry for the given row/col indices, or `None` if the indices are out of bounds.
|
||||
///
|
||||
/// Each call to this function incurs the cost of a binary search among the explicitly
|
||||
/// stored column entries for the given row.
|
||||
pub fn get_entry(&self, row_index: usize, col_index: usize) -> Option<SparseEntry<T>> {
|
||||
let row_range = self.get_index_range(row_index)?;
|
||||
let col_indices = &self.col_indices()[row_range.clone()];
|
||||
let values = &self.values()[row_range];
|
||||
get_entry_from_slices(self.ncols(), col_indices, values, col_index)
|
||||
}
|
||||
|
||||
/// Returns a mutable entry for the given row/col indices, or `None` if the indices are out
|
||||
/// of bounds.
|
||||
///
|
||||
/// Each call to this function incurs the cost of a binary search among the explicitly
|
||||
/// stored column entries for the given row.
|
||||
pub fn get_entry_mut(&mut self, row_index: usize, col_index: usize)
|
||||
-> Option<SparseEntryMut<T>> {
|
||||
let row_range = self.get_index_range(row_index)?;
|
||||
let ncols = self.ncols();
|
||||
let (_, col_indices, values) = self.csr_data_mut();
|
||||
let col_indices = &col_indices[row_range.clone()];
|
||||
let values = &mut values[row_range];
|
||||
get_mut_entry_from_slices(ncols, col_indices, values, col_index)
|
||||
}
|
||||
|
||||
/// Returns an entry for the given row/col indices.
|
||||
///
|
||||
/// Same as `get_entry`, except that it directly panics upon encountering row/col indices
|
||||
/// out of bounds.
|
||||
///
|
||||
/// Panics
|
||||
/// ------
|
||||
/// Panics if `row_index` or `col_index` is out of bounds.
|
||||
pub fn index_entry(&self, row_index: usize, col_index: usize) -> SparseEntry<T> {
|
||||
self.get_entry(row_index, col_index)
|
||||
.expect("Out of bounds matrix indices encountered")
|
||||
}
|
||||
|
||||
/// Returns a mutable entry for the given row/col indices.
|
||||
///
|
||||
/// Same as `get_entry_mut`, except that it directly panics upon encountering row/col indices
|
||||
/// out of bounds.
|
||||
///
|
||||
/// Panics
|
||||
/// ------
|
||||
/// Panics if `row_index` or `col_index` is out of bounds.
|
||||
pub fn index_entry_mut(&mut self, row_index: usize, col_index: usize) -> SparseEntryMut<T> {
|
||||
self.get_entry_mut(row_index, col_index)
|
||||
.expect("Out of bounds matrix indices encountered")
|
||||
}
|
||||
|
||||
/// Returns a triplet of slices `(row_offsets, col_indices, values)` that make up the CSR data.
|
||||
pub fn csr_data(&self) -> (&[usize], &[usize], &[T]) {
|
||||
(self.row_offsets(), self.col_indices(), self.values())
|
||||
}
|
||||
|
||||
/// Returns a triplet of slices `(row_offsets, col_indices, values)` that make up the CSR data,
|
||||
/// where the `values` array is mutable.
|
||||
pub fn csr_data_mut(&mut self) -> (&[usize], &[usize], &mut [T]) {
|
||||
let pattern = self.sparsity_pattern.as_ref();
|
||||
(pattern.major_offsets(), pattern.minor_indices(), &mut self.values)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> CsrMatrix<T>
|
||||
@ -465,32 +503,46 @@ macro_rules! impl_csr_row_common_methods {
|
||||
pub fn values(&self) -> &[T] {
|
||||
self.values
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T: Clone + Zero> $name {
|
||||
/// Return the value in the matrix at the given global column index, or `None` if out of
|
||||
/// bounds.
|
||||
///
|
||||
/// If the index is in bounds, but no explicitly stored entry is associated with it,
|
||||
/// `T::zero()` is returned. Note that this method offers no way of distinguishing
|
||||
/// explicitly stored zero entries from zero values that are only implicitly represented.
|
||||
/// Returns an entry for the given global column index.
|
||||
///
|
||||
/// Each call to this function incurs the cost of a binary search among the explicitly
|
||||
/// stored column entries for the current row.
|
||||
pub fn get(&self, global_col_index: usize) -> Option<T> {
|
||||
let local_index = self.col_indices().binary_search(&global_col_index);
|
||||
if let Ok(local_index) = local_index {
|
||||
Some(self.values[local_index].clone())
|
||||
} else if global_col_index < self.ncols {
|
||||
Some(T::zero())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
/// stored column entries.
|
||||
pub fn get_entry(&self, global_col_index: usize) -> Option<SparseEntry<T>> {
|
||||
get_entry_from_slices(self.ncols, self.col_indices, self.values, global_col_index)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn get_entry_from_slices<'a, T>(ncols: usize,
|
||||
col_indices: &'a [usize],
|
||||
values: &'a [T],
|
||||
global_col_index: usize) -> Option<SparseEntry<'a, T>> {
|
||||
let local_index = col_indices.binary_search(&global_col_index);
|
||||
if let Ok(local_index) = local_index {
|
||||
Some(SparseEntry::NonZero(&values[local_index]))
|
||||
} else if global_col_index < ncols {
|
||||
Some(SparseEntry::Zero)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn get_mut_entry_from_slices<'a, T>(ncols: usize,
|
||||
col_indices: &'a [usize],
|
||||
values: &'a mut [T],
|
||||
global_col_index: usize) -> Option<SparseEntryMut<'a, T>> {
|
||||
let local_index = col_indices.binary_search(&global_col_index);
|
||||
if let Ok(local_index) = local_index {
|
||||
Some(SparseEntryMut::NonZero(&mut values[local_index]))
|
||||
} else if global_col_index < ncols {
|
||||
Some(SparseEntryMut::Zero)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
impl_csr_row_common_methods!(CsrRow<'a, T>);
|
||||
impl_csr_row_common_methods!(CsrRowMut<'a, T>);
|
||||
|
||||
@ -508,6 +560,11 @@ impl<'a, T> CsrRowMut<'a, T> {
|
||||
pub fn cols_and_values_mut(&mut self) -> (&[usize], &mut [T]) {
|
||||
(self.col_indices, self.values)
|
||||
}
|
||||
|
||||
/// Returns a mutable entry for the given global column index.
|
||||
pub fn get_entry_mut(&mut self, global_col_index: usize) -> Option<SparseEntryMut<T>> {
|
||||
get_mut_entry_from_slices(self.ncols, self.col_indices, self.values, global_col_index)
|
||||
}
|
||||
}
|
||||
|
||||
/// Row iterator for [CsrMatrix](struct.CsrMatrix.html).
|
||||
|
@ -95,6 +95,7 @@ pub mod proptest;
|
||||
|
||||
use std::error::Error;
|
||||
use std::fmt;
|
||||
use num_traits::Zero;
|
||||
|
||||
/// Errors produced by functions that expect well-formed sparse format data.
|
||||
#[derive(Debug)]
|
||||
@ -148,4 +149,42 @@ impl fmt::Display for SparseFormatError {
|
||||
}
|
||||
}
|
||||
|
||||
impl Error for SparseFormatError {}
|
||||
impl Error for SparseFormatError {}
|
||||
|
||||
/// TODO
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
pub enum SparseEntry<'a, T> {
|
||||
/// TODO
|
||||
NonZero(&'a T),
|
||||
/// TODO
|
||||
Zero
|
||||
}
|
||||
|
||||
impl<'a, T: Clone + Zero> SparseEntry<'a, T> {
|
||||
/// TODO
|
||||
pub fn to_value(self) -> T {
|
||||
match self {
|
||||
SparseEntry::NonZero(value) => value.clone(),
|
||||
SparseEntry::Zero => T::zero()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// TODO
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
pub enum SparseEntryMut<'a, T> {
|
||||
/// TODO
|
||||
NonZero(&'a mut T),
|
||||
/// TODO
|
||||
Zero
|
||||
}
|
||||
|
||||
impl<'a, T: Clone + Zero> SparseEntryMut<'a, T> {
|
||||
/// TODO
|
||||
pub fn to_value(self) -> T {
|
||||
match self {
|
||||
SparseEntryMut::NonZero(value) => value.clone(),
|
||||
SparseEntryMut::Zero => T::zero()
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user