Rename CsrMatrix::get(_mut) to get_entry(_mut) and change semantics
This commit is contained in:
parent
830df6d07b
commit
921686c490
|
@ -1,6 +1,5 @@
|
||||||
//! An implementation of the CSR sparse matrix format.
|
//! An implementation of the CSR sparse matrix format.
|
||||||
|
use crate::{SparseFormatError, SparseFormatErrorKind, SparseEntry, SparseEntryMut};
|
||||||
use crate::{SparseFormatError, SparseFormatErrorKind};
|
|
||||||
use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter};
|
use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter};
|
||||||
use crate::csc::CscMatrix;
|
use crate::csc::CscMatrix;
|
||||||
|
|
||||||
|
@ -298,40 +297,79 @@ impl<T> CsrMatrix<T> {
|
||||||
pub fn pattern(&self) -> &Arc<SparsityPattern> {
|
pub fn pattern(&self) -> &Arc<SparsityPattern> {
|
||||||
&self.sparsity_pattern
|
&self.sparsity_pattern
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
impl<T: Clone + Zero> CsrMatrix<T> {
|
|
||||||
/// Return the value in the matrix at the given global row/col indices, or `None` if out of
|
|
||||||
/// bounds.
|
|
||||||
///
|
|
||||||
/// If the indices are in bounds, but no explicitly stored entry is associated with it,
|
|
||||||
/// `T::zero()` is returned. Note that this method offers no way of distinguishing
|
|
||||||
/// explicitly stored zero entries from zero values that are only implicitly represented.
|
|
||||||
///
|
|
||||||
/// Each call to this function incurs the cost of a binary search among the explicitly
|
|
||||||
/// stored column entries for the given row.
|
|
||||||
#[inline]
|
|
||||||
pub fn get(&self, row_index: usize, col_index: usize) -> Option<T> {
|
|
||||||
self.get_row(row_index)?.get(col_index)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Same as `get`, but panics if indices are out of bounds.
|
|
||||||
///
|
|
||||||
/// Panics
|
|
||||||
/// ------
|
|
||||||
/// Panics if either index is out of bounds.
|
|
||||||
#[inline]
|
|
||||||
pub fn index(&self, row_index: usize, col_index: usize) -> T {
|
|
||||||
self.get(row_index, col_index).unwrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Reinterprets the CSR matrix as its transpose represented by a CSC matrix.
|
/// Reinterprets the CSR matrix as its transpose represented by a CSC matrix.
|
||||||
|
///
|
||||||
/// This operation does not touch the CSR data, and is effectively a no-op.
|
/// This operation does not touch the CSR data, and is effectively a no-op.
|
||||||
pub fn transpose_as_csc(self) -> CscMatrix<T> {
|
pub fn transpose_as_csc(self) -> CscMatrix<T> {
|
||||||
let pattern = self.sparsity_pattern;
|
let pattern = self.sparsity_pattern;
|
||||||
let values = self.values;
|
let values = self.values;
|
||||||
CscMatrix::try_from_pattern_and_values(pattern, values).unwrap()
|
CscMatrix::try_from_pattern_and_values(pattern, values).unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns an entry for the given row/col indices, or `None` if the indices are out of bounds.
|
||||||
|
///
|
||||||
|
/// Each call to this function incurs the cost of a binary search among the explicitly
|
||||||
|
/// stored column entries for the given row.
|
||||||
|
pub fn get_entry(&self, row_index: usize, col_index: usize) -> Option<SparseEntry<T>> {
|
||||||
|
let row_range = self.get_index_range(row_index)?;
|
||||||
|
let col_indices = &self.col_indices()[row_range.clone()];
|
||||||
|
let values = &self.values()[row_range];
|
||||||
|
get_entry_from_slices(self.ncols(), col_indices, values, col_index)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns a mutable entry for the given row/col indices, or `None` if the indices are out
|
||||||
|
/// of bounds.
|
||||||
|
///
|
||||||
|
/// Each call to this function incurs the cost of a binary search among the explicitly
|
||||||
|
/// stored column entries for the given row.
|
||||||
|
pub fn get_entry_mut(&mut self, row_index: usize, col_index: usize)
|
||||||
|
-> Option<SparseEntryMut<T>> {
|
||||||
|
let row_range = self.get_index_range(row_index)?;
|
||||||
|
let ncols = self.ncols();
|
||||||
|
let (_, col_indices, values) = self.csr_data_mut();
|
||||||
|
let col_indices = &col_indices[row_range.clone()];
|
||||||
|
let values = &mut values[row_range];
|
||||||
|
get_mut_entry_from_slices(ncols, col_indices, values, col_index)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns an entry for the given row/col indices.
|
||||||
|
///
|
||||||
|
/// Same as `get_entry`, except that it directly panics upon encountering row/col indices
|
||||||
|
/// out of bounds.
|
||||||
|
///
|
||||||
|
/// Panics
|
||||||
|
/// ------
|
||||||
|
/// Panics if `row_index` or `col_index` is out of bounds.
|
||||||
|
pub fn index_entry(&self, row_index: usize, col_index: usize) -> SparseEntry<T> {
|
||||||
|
self.get_entry(row_index, col_index)
|
||||||
|
.expect("Out of bounds matrix indices encountered")
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns a mutable entry for the given row/col indices.
|
||||||
|
///
|
||||||
|
/// Same as `get_entry_mut`, except that it directly panics upon encountering row/col indices
|
||||||
|
/// out of bounds.
|
||||||
|
///
|
||||||
|
/// Panics
|
||||||
|
/// ------
|
||||||
|
/// Panics if `row_index` or `col_index` is out of bounds.
|
||||||
|
pub fn index_entry_mut(&mut self, row_index: usize, col_index: usize) -> SparseEntryMut<T> {
|
||||||
|
self.get_entry_mut(row_index, col_index)
|
||||||
|
.expect("Out of bounds matrix indices encountered")
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns a triplet of slices `(row_offsets, col_indices, values)` that make up the CSR data.
|
||||||
|
pub fn csr_data(&self) -> (&[usize], &[usize], &[T]) {
|
||||||
|
(self.row_offsets(), self.col_indices(), self.values())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns a triplet of slices `(row_offsets, col_indices, values)` that make up the CSR data,
|
||||||
|
/// where the `values` array is mutable.
|
||||||
|
pub fn csr_data_mut(&mut self) -> (&[usize], &[usize], &mut [T]) {
|
||||||
|
let pattern = self.sparsity_pattern.as_ref();
|
||||||
|
(pattern.major_offsets(), pattern.minor_indices(), &mut self.values)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T> CsrMatrix<T>
|
impl<T> CsrMatrix<T>
|
||||||
|
@ -465,32 +503,46 @@ macro_rules! impl_csr_row_common_methods {
|
||||||
pub fn values(&self) -> &[T] {
|
pub fn values(&self) -> &[T] {
|
||||||
self.values
|
self.values
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
impl<'a, T: Clone + Zero> $name {
|
/// Returns an entry for the given global column index.
|
||||||
/// Return the value in the matrix at the given global column index, or `None` if out of
|
|
||||||
/// bounds.
|
|
||||||
///
|
|
||||||
/// If the index is in bounds, but no explicitly stored entry is associated with it,
|
|
||||||
/// `T::zero()` is returned. Note that this method offers no way of distinguishing
|
|
||||||
/// explicitly stored zero entries from zero values that are only implicitly represented.
|
|
||||||
///
|
///
|
||||||
/// Each call to this function incurs the cost of a binary search among the explicitly
|
/// Each call to this function incurs the cost of a binary search among the explicitly
|
||||||
/// stored column entries for the current row.
|
/// stored column entries.
|
||||||
pub fn get(&self, global_col_index: usize) -> Option<T> {
|
pub fn get_entry(&self, global_col_index: usize) -> Option<SparseEntry<T>> {
|
||||||
let local_index = self.col_indices().binary_search(&global_col_index);
|
get_entry_from_slices(self.ncols, self.col_indices, self.values, global_col_index)
|
||||||
if let Ok(local_index) = local_index {
|
|
||||||
Some(self.values[local_index].clone())
|
|
||||||
} else if global_col_index < self.ncols {
|
|
||||||
Some(T::zero())
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn get_entry_from_slices<'a, T>(ncols: usize,
|
||||||
|
col_indices: &'a [usize],
|
||||||
|
values: &'a [T],
|
||||||
|
global_col_index: usize) -> Option<SparseEntry<'a, T>> {
|
||||||
|
let local_index = col_indices.binary_search(&global_col_index);
|
||||||
|
if let Ok(local_index) = local_index {
|
||||||
|
Some(SparseEntry::NonZero(&values[local_index]))
|
||||||
|
} else if global_col_index < ncols {
|
||||||
|
Some(SparseEntry::Zero)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_mut_entry_from_slices<'a, T>(ncols: usize,
|
||||||
|
col_indices: &'a [usize],
|
||||||
|
values: &'a mut [T],
|
||||||
|
global_col_index: usize) -> Option<SparseEntryMut<'a, T>> {
|
||||||
|
let local_index = col_indices.binary_search(&global_col_index);
|
||||||
|
if let Ok(local_index) = local_index {
|
||||||
|
Some(SparseEntryMut::NonZero(&mut values[local_index]))
|
||||||
|
} else if global_col_index < ncols {
|
||||||
|
Some(SparseEntryMut::Zero)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl_csr_row_common_methods!(CsrRow<'a, T>);
|
impl_csr_row_common_methods!(CsrRow<'a, T>);
|
||||||
impl_csr_row_common_methods!(CsrRowMut<'a, T>);
|
impl_csr_row_common_methods!(CsrRowMut<'a, T>);
|
||||||
|
|
||||||
|
@ -508,6 +560,11 @@ impl<'a, T> CsrRowMut<'a, T> {
|
||||||
pub fn cols_and_values_mut(&mut self) -> (&[usize], &mut [T]) {
|
pub fn cols_and_values_mut(&mut self) -> (&[usize], &mut [T]) {
|
||||||
(self.col_indices, self.values)
|
(self.col_indices, self.values)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns a mutable entry for the given global column index.
|
||||||
|
pub fn get_entry_mut(&mut self, global_col_index: usize) -> Option<SparseEntryMut<T>> {
|
||||||
|
get_mut_entry_from_slices(self.ncols, self.col_indices, self.values, global_col_index)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Row iterator for [CsrMatrix](struct.CsrMatrix.html).
|
/// Row iterator for [CsrMatrix](struct.CsrMatrix.html).
|
||||||
|
|
|
@ -95,6 +95,7 @@ pub mod proptest;
|
||||||
|
|
||||||
use std::error::Error;
|
use std::error::Error;
|
||||||
use std::fmt;
|
use std::fmt;
|
||||||
|
use num_traits::Zero;
|
||||||
|
|
||||||
/// Errors produced by functions that expect well-formed sparse format data.
|
/// Errors produced by functions that expect well-formed sparse format data.
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
|
@ -148,4 +149,42 @@ impl fmt::Display for SparseFormatError {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Error for SparseFormatError {}
|
impl Error for SparseFormatError {}
|
||||||
|
|
||||||
|
/// TODO
|
||||||
|
#[derive(Debug, PartialEq, Eq)]
|
||||||
|
pub enum SparseEntry<'a, T> {
|
||||||
|
/// TODO
|
||||||
|
NonZero(&'a T),
|
||||||
|
/// TODO
|
||||||
|
Zero
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, T: Clone + Zero> SparseEntry<'a, T> {
|
||||||
|
/// TODO
|
||||||
|
pub fn to_value(self) -> T {
|
||||||
|
match self {
|
||||||
|
SparseEntry::NonZero(value) => value.clone(),
|
||||||
|
SparseEntry::Zero => T::zero()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// TODO
|
||||||
|
#[derive(Debug, PartialEq, Eq)]
|
||||||
|
pub enum SparseEntryMut<'a, T> {
|
||||||
|
/// TODO
|
||||||
|
NonZero(&'a mut T),
|
||||||
|
/// TODO
|
||||||
|
Zero
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, T: Clone + Zero> SparseEntryMut<'a, T> {
|
||||||
|
/// TODO
|
||||||
|
pub fn to_value(self) -> T {
|
||||||
|
match self {
|
||||||
|
SparseEntryMut::NonZero(value) => value.clone(),
|
||||||
|
SparseEntryMut::Zero => T::zero()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue