adds row and column iterators for sparse matrices
This commit is contained in:
parent
96d4d98811
commit
4788dd19c9
|
@ -1,5 +1,7 @@
|
|||
use std::iter::Zip;
|
||||
use std::mem::replace;
|
||||
use std::ops::Range;
|
||||
use std::slice::{Iter, IterMut};
|
||||
|
||||
use num_traits::One;
|
||||
|
||||
|
@ -272,6 +274,54 @@ fn get_mut_entry_from_slices<'a, T>(
|
|||
}
|
||||
}
|
||||
|
||||
pub struct CsInnerIter<'a, T> {
|
||||
inner_iter: Zip<Iter<'a, usize>, Iter<'a, T>>,
|
||||
}
|
||||
|
||||
pub struct CsInnerIterMut<'a, T> {
|
||||
inner_iter_mut: Zip<Iter<'a, usize>, IterMut<'a, T>>,
|
||||
}
|
||||
|
||||
impl<'a, T> CsInnerIter<'a, T> {
|
||||
pub fn new(indices: &'a [usize], values: &'a [T]) -> Self {
|
||||
Self {
|
||||
inner_iter: indices.iter().zip(values.iter()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T> CsInnerIterMut<'a, T> {
|
||||
pub fn new(indices: &'a [usize], values: &'a mut [T]) -> Self {
|
||||
Self {
|
||||
inner_iter_mut: indices.iter().zip(values.iter_mut()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T> Iterator for CsInnerIter<'a, T>
|
||||
where
|
||||
T: 'a,
|
||||
{
|
||||
type Item = (usize, &'a T);
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
self.inner_iter.next().map(|(index, value)| (*index, value))
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T> Iterator for CsInnerIterMut<'a, T>
|
||||
where
|
||||
T: 'a,
|
||||
{
|
||||
type Item = (usize, &'a mut T);
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
self.inner_iter_mut
|
||||
.next()
|
||||
.map(|(index, value)| (*index, value))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct CsLane<'a, T> {
|
||||
minor_dim: usize,
|
||||
|
@ -415,6 +465,12 @@ macro_rules! impl_cs_lane_common_methods {
|
|||
global_col_index,
|
||||
)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
#[must_use]
|
||||
pub fn iter(&self) -> CsInnerIter<'_, T> {
|
||||
CsInnerIter::new(self.minor_indices, self.values)
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
@ -440,6 +496,12 @@ impl<'a, T> CsLaneMut<'a, T> {
|
|||
global_minor_index,
|
||||
)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
#[must_use]
|
||||
pub fn iter_mut(&mut self) -> CsInnerIterMut<'_, T> {
|
||||
CsInnerIterMut::new(self.minor_indices, self.values)
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper struct for working with uninitialized data in vectors.
|
||||
|
|
|
@ -7,7 +7,9 @@
|
|||
mod csc_serde;
|
||||
|
||||
use crate::cs;
|
||||
use crate::cs::{CsLane, CsLaneIter, CsLaneIterMut, CsLaneMut, CsMatrix};
|
||||
use crate::cs::{
|
||||
CsInnerIter, CsInnerIterMut, CsLane, CsLaneIter, CsLaneIterMut, CsLaneMut, CsMatrix,
|
||||
};
|
||||
use crate::csr::CsrMatrix;
|
||||
use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter};
|
||||
use crate::{SparseEntry, SparseEntryMut, SparseFormatError, SparseFormatErrorKind};
|
||||
|
@ -717,6 +719,14 @@ macro_rules! impl_csc_col_common_methods {
|
|||
pub fn get_entry(&self, global_row_index: usize) -> Option<SparseEntry<'_, T>> {
|
||||
self.lane.get_entry(global_row_index)
|
||||
}
|
||||
|
||||
/// Iterator over the row indices and elements of a column of a CSC matrix.
|
||||
/// Equivalent to `col.row_indices().iter().zip(col.values().iter())`.
|
||||
#[inline]
|
||||
#[must_use]
|
||||
pub fn iter(&self) -> CsInnerIter<'_, T> {
|
||||
self.lane.iter()
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
@ -744,6 +754,14 @@ impl<'a, T> CscColMut<'a, T> {
|
|||
pub fn get_entry_mut(&mut self, global_row_index: usize) -> Option<SparseEntryMut<'_, T>> {
|
||||
self.lane.get_entry_mut(global_row_index)
|
||||
}
|
||||
|
||||
/// Iterator over the row indices and mutable elements of a column of a CSC matrix.
|
||||
/// Equivalent to `col.row_indices().iter().zip(col.values_mut().iter_mut())`.
|
||||
#[inline]
|
||||
#[must_use]
|
||||
pub fn iter_mut(&mut self) -> CsInnerIterMut<'_, T> {
|
||||
self.lane.iter_mut()
|
||||
}
|
||||
}
|
||||
|
||||
/// Column iterator for [CscMatrix](struct.CscMatrix.html).
|
||||
|
|
|
@ -7,7 +7,9 @@
|
|||
mod csr_serde;
|
||||
|
||||
use crate::cs;
|
||||
use crate::cs::{CsLane, CsLaneIter, CsLaneIterMut, CsLaneMut, CsMatrix};
|
||||
use crate::cs::{
|
||||
CsInnerIter, CsInnerIterMut, CsLane, CsLaneIter, CsLaneIterMut, CsLaneMut, CsMatrix,
|
||||
};
|
||||
use crate::csc::CscMatrix;
|
||||
use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter};
|
||||
use crate::{SparseEntry, SparseEntryMut, SparseFormatError, SparseFormatErrorKind};
|
||||
|
@ -719,6 +721,14 @@ macro_rules! impl_csr_row_common_methods {
|
|||
pub fn get_entry(&self, global_col_index: usize) -> Option<SparseEntry<'_, T>> {
|
||||
self.lane.get_entry(global_col_index)
|
||||
}
|
||||
|
||||
/// Iterator over the column indices and elements of a row of a CSR matrix.
|
||||
/// Equivalent to `row.col_indices().iter().zip(row.values().iter())`.
|
||||
#[inline]
|
||||
#[must_use]
|
||||
pub fn iter(&self) -> CsInnerIter<'_, T> {
|
||||
self.lane.iter()
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
@ -749,6 +759,14 @@ impl<'a, T> CsrRowMut<'a, T> {
|
|||
pub fn get_entry_mut(&mut self, global_col_index: usize) -> Option<SparseEntryMut<'_, T>> {
|
||||
self.lane.get_entry_mut(global_col_index)
|
||||
}
|
||||
|
||||
/// Iterator over the column indices and mutable elements of a row of a CSR matrix.
|
||||
/// Equivalent to `row.col_indices().iter().zip(row.values_mut().iter_mut())`.
|
||||
#[inline]
|
||||
#[must_use]
|
||||
pub fn iter_mut(&mut self) -> CsInnerIterMut<'_, T> {
|
||||
self.lane.iter_mut()
|
||||
}
|
||||
}
|
||||
|
||||
/// Row iterator for [CsrMatrix](struct.CsrMatrix.html).
|
||||
|
|
|
@ -537,6 +537,11 @@ fn csc_matrix_col_iter() {
|
|||
assert_eq!(col.get_entry(2), Some(SparseEntry::NonZero(&2)));
|
||||
assert_eq!(col.get_entry(3), Some(SparseEntry::Zero));
|
||||
assert_eq!(col.get_entry(4), None);
|
||||
|
||||
let mut inner = col.iter();
|
||||
assert_eq!(Some((1, &1)), inner.next());
|
||||
assert_eq!(Some((2, &2)), inner.next());
|
||||
assert!(inner.next().is_none());
|
||||
}
|
||||
|
||||
{
|
||||
|
@ -550,6 +555,10 @@ fn csc_matrix_col_iter() {
|
|||
assert_eq!(col.get_entry(2), Some(SparseEntry::Zero));
|
||||
assert_eq!(col.get_entry(3), Some(SparseEntry::Zero));
|
||||
assert_eq!(col.get_entry(4), None);
|
||||
|
||||
let mut inner = col.iter();
|
||||
assert_eq!(Some((0, &3)), inner.next());
|
||||
assert!(inner.next().is_none());
|
||||
}
|
||||
|
||||
{
|
||||
|
@ -563,6 +572,11 @@ fn csc_matrix_col_iter() {
|
|||
assert_eq!(col.get_entry(2), Some(SparseEntry::Zero));
|
||||
assert_eq!(col.get_entry(3), Some(SparseEntry::NonZero(&5)));
|
||||
assert_eq!(col.get_entry(4), None);
|
||||
|
||||
let mut inner = col.iter();
|
||||
assert_eq!(Some((1, &4)), inner.next());
|
||||
assert_eq!(Some((3, &5)), inner.next());
|
||||
assert!(inner.next().is_none());
|
||||
}
|
||||
|
||||
assert!(col_iter.next().is_none());
|
||||
|
@ -595,6 +609,11 @@ fn csc_matrix_col_iter() {
|
|||
assert_eq!(col.get_entry_mut(2), Some(SparseEntryMut::NonZero(&mut 2)));
|
||||
assert_eq!(col.get_entry_mut(3), Some(SparseEntryMut::Zero));
|
||||
assert_eq!(col.get_entry_mut(4), None);
|
||||
|
||||
let mut inner = col.iter_mut();
|
||||
assert_eq!(Some((1, &mut 1)), inner.next());
|
||||
assert_eq!(Some((2, &mut 2)), inner.next());
|
||||
assert!(inner.next().is_none());
|
||||
}
|
||||
|
||||
{
|
||||
|
@ -616,6 +635,10 @@ fn csc_matrix_col_iter() {
|
|||
assert_eq!(col.get_entry_mut(2), Some(SparseEntryMut::Zero));
|
||||
assert_eq!(col.get_entry_mut(3), Some(SparseEntryMut::Zero));
|
||||
assert_eq!(col.get_entry_mut(4), None);
|
||||
|
||||
let mut inner = col.iter_mut();
|
||||
assert_eq!(Some((0, &mut 3)), inner.next());
|
||||
assert!(inner.next().is_none());
|
||||
}
|
||||
|
||||
{
|
||||
|
@ -640,6 +663,11 @@ fn csc_matrix_col_iter() {
|
|||
assert_eq!(col.get_entry_mut(2), Some(SparseEntryMut::Zero));
|
||||
assert_eq!(col.get_entry_mut(3), Some(SparseEntryMut::NonZero(&mut 5)));
|
||||
assert_eq!(col.get_entry_mut(4), None);
|
||||
|
||||
let mut inner = col.iter_mut();
|
||||
assert_eq!(Some((1, &mut 4)), inner.next());
|
||||
assert_eq!(Some((3, &mut 5)), inner.next());
|
||||
assert!(inner.next().is_none());
|
||||
}
|
||||
|
||||
assert!(col_iter.next().is_none());
|
||||
|
|
|
@ -531,6 +531,11 @@ fn csr_matrix_row_iter() {
|
|||
assert_eq!(row.get_entry(2), Some(SparseEntry::NonZero(&2)));
|
||||
assert_eq!(row.get_entry(3), Some(SparseEntry::Zero));
|
||||
assert_eq!(row.get_entry(4), None);
|
||||
|
||||
let mut inner = row.iter();
|
||||
assert_eq!(Some((1, &1)), inner.next());
|
||||
assert_eq!(Some((2, &2)), inner.next());
|
||||
assert!(inner.next().is_none());
|
||||
}
|
||||
|
||||
{
|
||||
|
@ -544,6 +549,10 @@ fn csr_matrix_row_iter() {
|
|||
assert_eq!(row.get_entry(2), Some(SparseEntry::Zero));
|
||||
assert_eq!(row.get_entry(3), Some(SparseEntry::Zero));
|
||||
assert_eq!(row.get_entry(4), None);
|
||||
|
||||
let mut inner = row.iter();
|
||||
assert_eq!(Some((0, &3)), inner.next());
|
||||
assert!(inner.next().is_none());
|
||||
}
|
||||
|
||||
{
|
||||
|
@ -557,6 +566,11 @@ fn csr_matrix_row_iter() {
|
|||
assert_eq!(row.get_entry(2), Some(SparseEntry::Zero));
|
||||
assert_eq!(row.get_entry(3), Some(SparseEntry::NonZero(&5)));
|
||||
assert_eq!(row.get_entry(4), None);
|
||||
|
||||
let mut inner = row.iter();
|
||||
assert_eq!(Some((1, &4)), inner.next());
|
||||
assert_eq!(Some((3, &5)), inner.next());
|
||||
assert!(inner.next().is_none());
|
||||
}
|
||||
|
||||
assert!(row_iter.next().is_none());
|
||||
|
@ -589,6 +603,11 @@ fn csr_matrix_row_iter() {
|
|||
assert_eq!(row.get_entry_mut(2), Some(SparseEntryMut::NonZero(&mut 2)));
|
||||
assert_eq!(row.get_entry_mut(3), Some(SparseEntryMut::Zero));
|
||||
assert_eq!(row.get_entry_mut(4), None);
|
||||
|
||||
let mut inner = row.iter_mut();
|
||||
assert_eq!(Some((1, &mut 1)), inner.next());
|
||||
assert_eq!(Some((2, &mut 2)), inner.next());
|
||||
assert!(inner.next().is_none());
|
||||
}
|
||||
|
||||
{
|
||||
|
@ -610,6 +629,10 @@ fn csr_matrix_row_iter() {
|
|||
assert_eq!(row.get_entry_mut(2), Some(SparseEntryMut::Zero));
|
||||
assert_eq!(row.get_entry_mut(3), Some(SparseEntryMut::Zero));
|
||||
assert_eq!(row.get_entry_mut(4), None);
|
||||
|
||||
let mut inner = row.iter_mut();
|
||||
assert_eq!(Some((0, &mut 3)), inner.next());
|
||||
assert!(inner.next().is_none());
|
||||
}
|
||||
|
||||
{
|
||||
|
@ -634,6 +657,11 @@ fn csr_matrix_row_iter() {
|
|||
assert_eq!(row.get_entry_mut(2), Some(SparseEntryMut::Zero));
|
||||
assert_eq!(row.get_entry_mut(3), Some(SparseEntryMut::NonZero(&mut 5)));
|
||||
assert_eq!(row.get_entry_mut(4), None);
|
||||
|
||||
let mut inner = row.iter_mut();
|
||||
assert_eq!(Some((1, &mut 4)), inner.next());
|
||||
assert_eq!(Some((3, &mut 5)), inner.next());
|
||||
assert!(inner.next().is_none());
|
||||
}
|
||||
|
||||
assert!(row_iter.next().is_none());
|
||||
|
|
Loading…
Reference in New Issue