adds row and column iterators for sparse matrices

This commit is contained in:
Austen Nelson 2022-05-03 13:26:56 -07:00
parent 96d4d98811
commit 4788dd19c9
5 changed files with 156 additions and 2 deletions

View File

@ -1,5 +1,7 @@
use std::iter::Zip;
use std::mem::replace;
use std::ops::Range;
use std::slice::{Iter, IterMut};
use num_traits::One;
@ -272,6 +274,54 @@ fn get_mut_entry_from_slices<'a, T>(
}
}
pub struct CsInnerIter<'a, T> {
inner_iter: Zip<Iter<'a, usize>, Iter<'a, T>>,
}
pub struct CsInnerIterMut<'a, T> {
inner_iter_mut: Zip<Iter<'a, usize>, IterMut<'a, T>>,
}
impl<'a, T> CsInnerIter<'a, T> {
pub fn new(indices: &'a [usize], values: &'a [T]) -> Self {
Self {
inner_iter: indices.iter().zip(values.iter()),
}
}
}
impl<'a, T> CsInnerIterMut<'a, T> {
pub fn new(indices: &'a [usize], values: &'a mut [T]) -> Self {
Self {
inner_iter_mut: indices.iter().zip(values.iter_mut()),
}
}
}
impl<'a, T> Iterator for CsInnerIter<'a, T>
where
T: 'a,
{
type Item = (usize, &'a T);
fn next(&mut self) -> Option<Self::Item> {
self.inner_iter.next().map(|(index, value)| (*index, value))
}
}
impl<'a, T> Iterator for CsInnerIterMut<'a, T>
where
T: 'a,
{
type Item = (usize, &'a mut T);
fn next(&mut self) -> Option<Self::Item> {
self.inner_iter_mut
.next()
.map(|(index, value)| (*index, value))
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CsLane<'a, T> {
minor_dim: usize,
@ -415,6 +465,12 @@ macro_rules! impl_cs_lane_common_methods {
global_col_index,
)
}
#[inline]
#[must_use]
pub fn iter(&self) -> CsInnerIter<'_, T> {
CsInnerIter::new(self.minor_indices, self.values)
}
}
};
}
@ -440,6 +496,12 @@ impl<'a, T> CsLaneMut<'a, T> {
global_minor_index,
)
}
#[inline]
#[must_use]
pub fn iter_mut(&mut self) -> CsInnerIterMut<'_, T> {
CsInnerIterMut::new(self.minor_indices, self.values)
}
}
/// Helper struct for working with uninitialized data in vectors.

View File

@ -7,7 +7,9 @@
mod csc_serde;
use crate::cs;
use crate::cs::{CsLane, CsLaneIter, CsLaneIterMut, CsLaneMut, CsMatrix};
use crate::cs::{
CsInnerIter, CsInnerIterMut, CsLane, CsLaneIter, CsLaneIterMut, CsLaneMut, CsMatrix,
};
use crate::csr::CsrMatrix;
use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter};
use crate::{SparseEntry, SparseEntryMut, SparseFormatError, SparseFormatErrorKind};
@ -717,6 +719,14 @@ macro_rules! impl_csc_col_common_methods {
pub fn get_entry(&self, global_row_index: usize) -> Option<SparseEntry<'_, T>> {
self.lane.get_entry(global_row_index)
}
/// Iterator over the row indices and elements of a column of a CSC matrix.
/// Equivalent to `col.row_indices().iter().zip(col.values().iter())`.
#[inline]
#[must_use]
pub fn iter(&self) -> CsInnerIter<'_, T> {
self.lane.iter()
}
}
};
}
@ -744,6 +754,14 @@ impl<'a, T> CscColMut<'a, T> {
pub fn get_entry_mut(&mut self, global_row_index: usize) -> Option<SparseEntryMut<'_, T>> {
self.lane.get_entry_mut(global_row_index)
}
/// Iterator over the row indices and mutable elements of a column of a CSC matrix.
/// Equivalent to `col.row_indices().iter().zip(col.values_mut().iter_mut())`.
#[inline]
#[must_use]
pub fn iter_mut(&mut self) -> CsInnerIterMut<'_, T> {
self.lane.iter_mut()
}
}
/// Column iterator for [CscMatrix](struct.CscMatrix.html).

View File

@ -7,7 +7,9 @@
mod csr_serde;
use crate::cs;
use crate::cs::{CsLane, CsLaneIter, CsLaneIterMut, CsLaneMut, CsMatrix};
use crate::cs::{
CsInnerIter, CsInnerIterMut, CsLane, CsLaneIter, CsLaneIterMut, CsLaneMut, CsMatrix,
};
use crate::csc::CscMatrix;
use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter};
use crate::{SparseEntry, SparseEntryMut, SparseFormatError, SparseFormatErrorKind};
@ -719,6 +721,14 @@ macro_rules! impl_csr_row_common_methods {
pub fn get_entry(&self, global_col_index: usize) -> Option<SparseEntry<'_, T>> {
self.lane.get_entry(global_col_index)
}
/// Iterator over the column indices and elements of a row of a CSR matrix.
/// Equivalent to `row.col_indices().iter().zip(row.values().iter())`.
#[inline]
#[must_use]
pub fn iter(&self) -> CsInnerIter<'_, T> {
self.lane.iter()
}
}
};
}
@ -749,6 +759,14 @@ impl<'a, T> CsrRowMut<'a, T> {
pub fn get_entry_mut(&mut self, global_col_index: usize) -> Option<SparseEntryMut<'_, T>> {
self.lane.get_entry_mut(global_col_index)
}
/// Iterator over the column indices and mutable elements of a row of a CSR matrix.
/// Equivalent to `row.col_indices().iter().zip(row.values_mut().iter_mut())`.
#[inline]
#[must_use]
pub fn iter_mut(&mut self) -> CsInnerIterMut<'_, T> {
self.lane.iter_mut()
}
}
/// Row iterator for [CsrMatrix](struct.CsrMatrix.html).

View File

@ -537,6 +537,11 @@ fn csc_matrix_col_iter() {
assert_eq!(col.get_entry(2), Some(SparseEntry::NonZero(&2)));
assert_eq!(col.get_entry(3), Some(SparseEntry::Zero));
assert_eq!(col.get_entry(4), None);
let mut inner = col.iter();
assert_eq!(Some((1, &1)), inner.next());
assert_eq!(Some((2, &2)), inner.next());
assert!(inner.next().is_none());
}
{
@ -550,6 +555,10 @@ fn csc_matrix_col_iter() {
assert_eq!(col.get_entry(2), Some(SparseEntry::Zero));
assert_eq!(col.get_entry(3), Some(SparseEntry::Zero));
assert_eq!(col.get_entry(4), None);
let mut inner = col.iter();
assert_eq!(Some((0, &3)), inner.next());
assert!(inner.next().is_none());
}
{
@ -563,6 +572,11 @@ fn csc_matrix_col_iter() {
assert_eq!(col.get_entry(2), Some(SparseEntry::Zero));
assert_eq!(col.get_entry(3), Some(SparseEntry::NonZero(&5)));
assert_eq!(col.get_entry(4), None);
let mut inner = col.iter();
assert_eq!(Some((1, &4)), inner.next());
assert_eq!(Some((3, &5)), inner.next());
assert!(inner.next().is_none());
}
assert!(col_iter.next().is_none());
@ -595,6 +609,11 @@ fn csc_matrix_col_iter() {
assert_eq!(col.get_entry_mut(2), Some(SparseEntryMut::NonZero(&mut 2)));
assert_eq!(col.get_entry_mut(3), Some(SparseEntryMut::Zero));
assert_eq!(col.get_entry_mut(4), None);
let mut inner = col.iter_mut();
assert_eq!(Some((1, &mut 1)), inner.next());
assert_eq!(Some((2, &mut 2)), inner.next());
assert!(inner.next().is_none());
}
{
@ -616,6 +635,10 @@ fn csc_matrix_col_iter() {
assert_eq!(col.get_entry_mut(2), Some(SparseEntryMut::Zero));
assert_eq!(col.get_entry_mut(3), Some(SparseEntryMut::Zero));
assert_eq!(col.get_entry_mut(4), None);
let mut inner = col.iter_mut();
assert_eq!(Some((0, &mut 3)), inner.next());
assert!(inner.next().is_none());
}
{
@ -640,6 +663,11 @@ fn csc_matrix_col_iter() {
assert_eq!(col.get_entry_mut(2), Some(SparseEntryMut::Zero));
assert_eq!(col.get_entry_mut(3), Some(SparseEntryMut::NonZero(&mut 5)));
assert_eq!(col.get_entry_mut(4), None);
let mut inner = col.iter_mut();
assert_eq!(Some((1, &mut 4)), inner.next());
assert_eq!(Some((3, &mut 5)), inner.next());
assert!(inner.next().is_none());
}
assert!(col_iter.next().is_none());

View File

@ -531,6 +531,11 @@ fn csr_matrix_row_iter() {
assert_eq!(row.get_entry(2), Some(SparseEntry::NonZero(&2)));
assert_eq!(row.get_entry(3), Some(SparseEntry::Zero));
assert_eq!(row.get_entry(4), None);
let mut inner = row.iter();
assert_eq!(Some((1, &1)), inner.next());
assert_eq!(Some((2, &2)), inner.next());
assert!(inner.next().is_none());
}
{
@ -544,6 +549,10 @@ fn csr_matrix_row_iter() {
assert_eq!(row.get_entry(2), Some(SparseEntry::Zero));
assert_eq!(row.get_entry(3), Some(SparseEntry::Zero));
assert_eq!(row.get_entry(4), None);
let mut inner = row.iter();
assert_eq!(Some((0, &3)), inner.next());
assert!(inner.next().is_none());
}
{
@ -557,6 +566,11 @@ fn csr_matrix_row_iter() {
assert_eq!(row.get_entry(2), Some(SparseEntry::Zero));
assert_eq!(row.get_entry(3), Some(SparseEntry::NonZero(&5)));
assert_eq!(row.get_entry(4), None);
let mut inner = row.iter();
assert_eq!(Some((1, &4)), inner.next());
assert_eq!(Some((3, &5)), inner.next());
assert!(inner.next().is_none());
}
assert!(row_iter.next().is_none());
@ -589,6 +603,11 @@ fn csr_matrix_row_iter() {
assert_eq!(row.get_entry_mut(2), Some(SparseEntryMut::NonZero(&mut 2)));
assert_eq!(row.get_entry_mut(3), Some(SparseEntryMut::Zero));
assert_eq!(row.get_entry_mut(4), None);
let mut inner = row.iter_mut();
assert_eq!(Some((1, &mut 1)), inner.next());
assert_eq!(Some((2, &mut 2)), inner.next());
assert!(inner.next().is_none());
}
{
@ -610,6 +629,10 @@ fn csr_matrix_row_iter() {
assert_eq!(row.get_entry_mut(2), Some(SparseEntryMut::Zero));
assert_eq!(row.get_entry_mut(3), Some(SparseEntryMut::Zero));
assert_eq!(row.get_entry_mut(4), None);
let mut inner = row.iter_mut();
assert_eq!(Some((0, &mut 3)), inner.next());
assert!(inner.next().is_none());
}
{
@ -634,6 +657,11 @@ fn csr_matrix_row_iter() {
assert_eq!(row.get_entry_mut(2), Some(SparseEntryMut::Zero));
assert_eq!(row.get_entry_mut(3), Some(SparseEntryMut::NonZero(&mut 5)));
assert_eq!(row.get_entry_mut(4), None);
let mut inner = row.iter_mut();
assert_eq!(Some((1, &mut 4)), inner.next());
assert_eq!(Some((3, &mut 5)), inner.next());
assert!(inner.next().is_none());
}
assert!(row_iter.next().is_none());