adds row and column iterators for sparse matrices
This commit is contained in:
parent
96d4d98811
commit
4788dd19c9
|
@ -1,5 +1,7 @@
|
||||||
|
use std::iter::Zip;
|
||||||
use std::mem::replace;
|
use std::mem::replace;
|
||||||
use std::ops::Range;
|
use std::ops::Range;
|
||||||
|
use std::slice::{Iter, IterMut};
|
||||||
|
|
||||||
use num_traits::One;
|
use num_traits::One;
|
||||||
|
|
||||||
|
@ -272,6 +274,54 @@ fn get_mut_entry_from_slices<'a, T>(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub struct CsInnerIter<'a, T> {
|
||||||
|
inner_iter: Zip<Iter<'a, usize>, Iter<'a, T>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct CsInnerIterMut<'a, T> {
|
||||||
|
inner_iter_mut: Zip<Iter<'a, usize>, IterMut<'a, T>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, T> CsInnerIter<'a, T> {
|
||||||
|
pub fn new(indices: &'a [usize], values: &'a [T]) -> Self {
|
||||||
|
Self {
|
||||||
|
inner_iter: indices.iter().zip(values.iter()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, T> CsInnerIterMut<'a, T> {
|
||||||
|
pub fn new(indices: &'a [usize], values: &'a mut [T]) -> Self {
|
||||||
|
Self {
|
||||||
|
inner_iter_mut: indices.iter().zip(values.iter_mut()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, T> Iterator for CsInnerIter<'a, T>
|
||||||
|
where
|
||||||
|
T: 'a,
|
||||||
|
{
|
||||||
|
type Item = (usize, &'a T);
|
||||||
|
|
||||||
|
fn next(&mut self) -> Option<Self::Item> {
|
||||||
|
self.inner_iter.next().map(|(index, value)| (*index, value))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, T> Iterator for CsInnerIterMut<'a, T>
|
||||||
|
where
|
||||||
|
T: 'a,
|
||||||
|
{
|
||||||
|
type Item = (usize, &'a mut T);
|
||||||
|
|
||||||
|
fn next(&mut self) -> Option<Self::Item> {
|
||||||
|
self.inner_iter_mut
|
||||||
|
.next()
|
||||||
|
.map(|(index, value)| (*index, value))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
pub struct CsLane<'a, T> {
|
pub struct CsLane<'a, T> {
|
||||||
minor_dim: usize,
|
minor_dim: usize,
|
||||||
|
@ -415,6 +465,12 @@ macro_rules! impl_cs_lane_common_methods {
|
||||||
global_col_index,
|
global_col_index,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
#[must_use]
|
||||||
|
pub fn iter(&self) -> CsInnerIter<'_, T> {
|
||||||
|
CsInnerIter::new(self.minor_indices, self.values)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
@ -440,6 +496,12 @@ impl<'a, T> CsLaneMut<'a, T> {
|
||||||
global_minor_index,
|
global_minor_index,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
#[must_use]
|
||||||
|
pub fn iter_mut(&mut self) -> CsInnerIterMut<'_, T> {
|
||||||
|
CsInnerIterMut::new(self.minor_indices, self.values)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Helper struct for working with uninitialized data in vectors.
|
/// Helper struct for working with uninitialized data in vectors.
|
||||||
|
|
|
@ -7,7 +7,9 @@
|
||||||
mod csc_serde;
|
mod csc_serde;
|
||||||
|
|
||||||
use crate::cs;
|
use crate::cs;
|
||||||
use crate::cs::{CsLane, CsLaneIter, CsLaneIterMut, CsLaneMut, CsMatrix};
|
use crate::cs::{
|
||||||
|
CsInnerIter, CsInnerIterMut, CsLane, CsLaneIter, CsLaneIterMut, CsLaneMut, CsMatrix,
|
||||||
|
};
|
||||||
use crate::csr::CsrMatrix;
|
use crate::csr::CsrMatrix;
|
||||||
use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter};
|
use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter};
|
||||||
use crate::{SparseEntry, SparseEntryMut, SparseFormatError, SparseFormatErrorKind};
|
use crate::{SparseEntry, SparseEntryMut, SparseFormatError, SparseFormatErrorKind};
|
||||||
|
@ -717,6 +719,14 @@ macro_rules! impl_csc_col_common_methods {
|
||||||
pub fn get_entry(&self, global_row_index: usize) -> Option<SparseEntry<'_, T>> {
|
pub fn get_entry(&self, global_row_index: usize) -> Option<SparseEntry<'_, T>> {
|
||||||
self.lane.get_entry(global_row_index)
|
self.lane.get_entry(global_row_index)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Iterator over the row indices and elements of a column of a CSC matrix.
|
||||||
|
/// Equivalent to `col.row_indices().iter().zip(col.values().iter())`.
|
||||||
|
#[inline]
|
||||||
|
#[must_use]
|
||||||
|
pub fn iter(&self) -> CsInnerIter<'_, T> {
|
||||||
|
self.lane.iter()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
@ -744,6 +754,14 @@ impl<'a, T> CscColMut<'a, T> {
|
||||||
pub fn get_entry_mut(&mut self, global_row_index: usize) -> Option<SparseEntryMut<'_, T>> {
|
pub fn get_entry_mut(&mut self, global_row_index: usize) -> Option<SparseEntryMut<'_, T>> {
|
||||||
self.lane.get_entry_mut(global_row_index)
|
self.lane.get_entry_mut(global_row_index)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Iterator over the row indices and mutable elements of a column of a CSC matrix.
|
||||||
|
/// Equivalent to `col.row_indices().iter().zip(col.values_mut().iter_mut())`.
|
||||||
|
#[inline]
|
||||||
|
#[must_use]
|
||||||
|
pub fn iter_mut(&mut self) -> CsInnerIterMut<'_, T> {
|
||||||
|
self.lane.iter_mut()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Column iterator for [CscMatrix](struct.CscMatrix.html).
|
/// Column iterator for [CscMatrix](struct.CscMatrix.html).
|
||||||
|
|
|
@ -7,7 +7,9 @@
|
||||||
mod csr_serde;
|
mod csr_serde;
|
||||||
|
|
||||||
use crate::cs;
|
use crate::cs;
|
||||||
use crate::cs::{CsLane, CsLaneIter, CsLaneIterMut, CsLaneMut, CsMatrix};
|
use crate::cs::{
|
||||||
|
CsInnerIter, CsInnerIterMut, CsLane, CsLaneIter, CsLaneIterMut, CsLaneMut, CsMatrix,
|
||||||
|
};
|
||||||
use crate::csc::CscMatrix;
|
use crate::csc::CscMatrix;
|
||||||
use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter};
|
use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter};
|
||||||
use crate::{SparseEntry, SparseEntryMut, SparseFormatError, SparseFormatErrorKind};
|
use crate::{SparseEntry, SparseEntryMut, SparseFormatError, SparseFormatErrorKind};
|
||||||
|
@ -719,6 +721,14 @@ macro_rules! impl_csr_row_common_methods {
|
||||||
pub fn get_entry(&self, global_col_index: usize) -> Option<SparseEntry<'_, T>> {
|
pub fn get_entry(&self, global_col_index: usize) -> Option<SparseEntry<'_, T>> {
|
||||||
self.lane.get_entry(global_col_index)
|
self.lane.get_entry(global_col_index)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Iterator over the column indices and elements of a row of a CSR matrix.
|
||||||
|
/// Equivalent to `row.col_indices().iter().zip(row.values().iter())`.
|
||||||
|
#[inline]
|
||||||
|
#[must_use]
|
||||||
|
pub fn iter(&self) -> CsInnerIter<'_, T> {
|
||||||
|
self.lane.iter()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
@ -749,6 +759,14 @@ impl<'a, T> CsrRowMut<'a, T> {
|
||||||
pub fn get_entry_mut(&mut self, global_col_index: usize) -> Option<SparseEntryMut<'_, T>> {
|
pub fn get_entry_mut(&mut self, global_col_index: usize) -> Option<SparseEntryMut<'_, T>> {
|
||||||
self.lane.get_entry_mut(global_col_index)
|
self.lane.get_entry_mut(global_col_index)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Iterator over the column indices and mutable elements of a row of a CSR matrix.
|
||||||
|
/// Equivalent to `row.col_indices().iter().zip(row.values_mut().iter_mut())`.
|
||||||
|
#[inline]
|
||||||
|
#[must_use]
|
||||||
|
pub fn iter_mut(&mut self) -> CsInnerIterMut<'_, T> {
|
||||||
|
self.lane.iter_mut()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Row iterator for [CsrMatrix](struct.CsrMatrix.html).
|
/// Row iterator for [CsrMatrix](struct.CsrMatrix.html).
|
||||||
|
|
|
@ -537,6 +537,11 @@ fn csc_matrix_col_iter() {
|
||||||
assert_eq!(col.get_entry(2), Some(SparseEntry::NonZero(&2)));
|
assert_eq!(col.get_entry(2), Some(SparseEntry::NonZero(&2)));
|
||||||
assert_eq!(col.get_entry(3), Some(SparseEntry::Zero));
|
assert_eq!(col.get_entry(3), Some(SparseEntry::Zero));
|
||||||
assert_eq!(col.get_entry(4), None);
|
assert_eq!(col.get_entry(4), None);
|
||||||
|
|
||||||
|
let mut inner = col.iter();
|
||||||
|
assert_eq!(Some((1, &1)), inner.next());
|
||||||
|
assert_eq!(Some((2, &2)), inner.next());
|
||||||
|
assert!(inner.next().is_none());
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
|
@ -550,6 +555,10 @@ fn csc_matrix_col_iter() {
|
||||||
assert_eq!(col.get_entry(2), Some(SparseEntry::Zero));
|
assert_eq!(col.get_entry(2), Some(SparseEntry::Zero));
|
||||||
assert_eq!(col.get_entry(3), Some(SparseEntry::Zero));
|
assert_eq!(col.get_entry(3), Some(SparseEntry::Zero));
|
||||||
assert_eq!(col.get_entry(4), None);
|
assert_eq!(col.get_entry(4), None);
|
||||||
|
|
||||||
|
let mut inner = col.iter();
|
||||||
|
assert_eq!(Some((0, &3)), inner.next());
|
||||||
|
assert!(inner.next().is_none());
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
|
@ -563,6 +572,11 @@ fn csc_matrix_col_iter() {
|
||||||
assert_eq!(col.get_entry(2), Some(SparseEntry::Zero));
|
assert_eq!(col.get_entry(2), Some(SparseEntry::Zero));
|
||||||
assert_eq!(col.get_entry(3), Some(SparseEntry::NonZero(&5)));
|
assert_eq!(col.get_entry(3), Some(SparseEntry::NonZero(&5)));
|
||||||
assert_eq!(col.get_entry(4), None);
|
assert_eq!(col.get_entry(4), None);
|
||||||
|
|
||||||
|
let mut inner = col.iter();
|
||||||
|
assert_eq!(Some((1, &4)), inner.next());
|
||||||
|
assert_eq!(Some((3, &5)), inner.next());
|
||||||
|
assert!(inner.next().is_none());
|
||||||
}
|
}
|
||||||
|
|
||||||
assert!(col_iter.next().is_none());
|
assert!(col_iter.next().is_none());
|
||||||
|
@ -595,6 +609,11 @@ fn csc_matrix_col_iter() {
|
||||||
assert_eq!(col.get_entry_mut(2), Some(SparseEntryMut::NonZero(&mut 2)));
|
assert_eq!(col.get_entry_mut(2), Some(SparseEntryMut::NonZero(&mut 2)));
|
||||||
assert_eq!(col.get_entry_mut(3), Some(SparseEntryMut::Zero));
|
assert_eq!(col.get_entry_mut(3), Some(SparseEntryMut::Zero));
|
||||||
assert_eq!(col.get_entry_mut(4), None);
|
assert_eq!(col.get_entry_mut(4), None);
|
||||||
|
|
||||||
|
let mut inner = col.iter_mut();
|
||||||
|
assert_eq!(Some((1, &mut 1)), inner.next());
|
||||||
|
assert_eq!(Some((2, &mut 2)), inner.next());
|
||||||
|
assert!(inner.next().is_none());
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
|
@ -616,6 +635,10 @@ fn csc_matrix_col_iter() {
|
||||||
assert_eq!(col.get_entry_mut(2), Some(SparseEntryMut::Zero));
|
assert_eq!(col.get_entry_mut(2), Some(SparseEntryMut::Zero));
|
||||||
assert_eq!(col.get_entry_mut(3), Some(SparseEntryMut::Zero));
|
assert_eq!(col.get_entry_mut(3), Some(SparseEntryMut::Zero));
|
||||||
assert_eq!(col.get_entry_mut(4), None);
|
assert_eq!(col.get_entry_mut(4), None);
|
||||||
|
|
||||||
|
let mut inner = col.iter_mut();
|
||||||
|
assert_eq!(Some((0, &mut 3)), inner.next());
|
||||||
|
assert!(inner.next().is_none());
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
|
@ -640,6 +663,11 @@ fn csc_matrix_col_iter() {
|
||||||
assert_eq!(col.get_entry_mut(2), Some(SparseEntryMut::Zero));
|
assert_eq!(col.get_entry_mut(2), Some(SparseEntryMut::Zero));
|
||||||
assert_eq!(col.get_entry_mut(3), Some(SparseEntryMut::NonZero(&mut 5)));
|
assert_eq!(col.get_entry_mut(3), Some(SparseEntryMut::NonZero(&mut 5)));
|
||||||
assert_eq!(col.get_entry_mut(4), None);
|
assert_eq!(col.get_entry_mut(4), None);
|
||||||
|
|
||||||
|
let mut inner = col.iter_mut();
|
||||||
|
assert_eq!(Some((1, &mut 4)), inner.next());
|
||||||
|
assert_eq!(Some((3, &mut 5)), inner.next());
|
||||||
|
assert!(inner.next().is_none());
|
||||||
}
|
}
|
||||||
|
|
||||||
assert!(col_iter.next().is_none());
|
assert!(col_iter.next().is_none());
|
||||||
|
|
|
@ -531,6 +531,11 @@ fn csr_matrix_row_iter() {
|
||||||
assert_eq!(row.get_entry(2), Some(SparseEntry::NonZero(&2)));
|
assert_eq!(row.get_entry(2), Some(SparseEntry::NonZero(&2)));
|
||||||
assert_eq!(row.get_entry(3), Some(SparseEntry::Zero));
|
assert_eq!(row.get_entry(3), Some(SparseEntry::Zero));
|
||||||
assert_eq!(row.get_entry(4), None);
|
assert_eq!(row.get_entry(4), None);
|
||||||
|
|
||||||
|
let mut inner = row.iter();
|
||||||
|
assert_eq!(Some((1, &1)), inner.next());
|
||||||
|
assert_eq!(Some((2, &2)), inner.next());
|
||||||
|
assert!(inner.next().is_none());
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
|
@ -544,6 +549,10 @@ fn csr_matrix_row_iter() {
|
||||||
assert_eq!(row.get_entry(2), Some(SparseEntry::Zero));
|
assert_eq!(row.get_entry(2), Some(SparseEntry::Zero));
|
||||||
assert_eq!(row.get_entry(3), Some(SparseEntry::Zero));
|
assert_eq!(row.get_entry(3), Some(SparseEntry::Zero));
|
||||||
assert_eq!(row.get_entry(4), None);
|
assert_eq!(row.get_entry(4), None);
|
||||||
|
|
||||||
|
let mut inner = row.iter();
|
||||||
|
assert_eq!(Some((0, &3)), inner.next());
|
||||||
|
assert!(inner.next().is_none());
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
|
@ -557,6 +566,11 @@ fn csr_matrix_row_iter() {
|
||||||
assert_eq!(row.get_entry(2), Some(SparseEntry::Zero));
|
assert_eq!(row.get_entry(2), Some(SparseEntry::Zero));
|
||||||
assert_eq!(row.get_entry(3), Some(SparseEntry::NonZero(&5)));
|
assert_eq!(row.get_entry(3), Some(SparseEntry::NonZero(&5)));
|
||||||
assert_eq!(row.get_entry(4), None);
|
assert_eq!(row.get_entry(4), None);
|
||||||
|
|
||||||
|
let mut inner = row.iter();
|
||||||
|
assert_eq!(Some((1, &4)), inner.next());
|
||||||
|
assert_eq!(Some((3, &5)), inner.next());
|
||||||
|
assert!(inner.next().is_none());
|
||||||
}
|
}
|
||||||
|
|
||||||
assert!(row_iter.next().is_none());
|
assert!(row_iter.next().is_none());
|
||||||
|
@ -589,6 +603,11 @@ fn csr_matrix_row_iter() {
|
||||||
assert_eq!(row.get_entry_mut(2), Some(SparseEntryMut::NonZero(&mut 2)));
|
assert_eq!(row.get_entry_mut(2), Some(SparseEntryMut::NonZero(&mut 2)));
|
||||||
assert_eq!(row.get_entry_mut(3), Some(SparseEntryMut::Zero));
|
assert_eq!(row.get_entry_mut(3), Some(SparseEntryMut::Zero));
|
||||||
assert_eq!(row.get_entry_mut(4), None);
|
assert_eq!(row.get_entry_mut(4), None);
|
||||||
|
|
||||||
|
let mut inner = row.iter_mut();
|
||||||
|
assert_eq!(Some((1, &mut 1)), inner.next());
|
||||||
|
assert_eq!(Some((2, &mut 2)), inner.next());
|
||||||
|
assert!(inner.next().is_none());
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
|
@ -610,6 +629,10 @@ fn csr_matrix_row_iter() {
|
||||||
assert_eq!(row.get_entry_mut(2), Some(SparseEntryMut::Zero));
|
assert_eq!(row.get_entry_mut(2), Some(SparseEntryMut::Zero));
|
||||||
assert_eq!(row.get_entry_mut(3), Some(SparseEntryMut::Zero));
|
assert_eq!(row.get_entry_mut(3), Some(SparseEntryMut::Zero));
|
||||||
assert_eq!(row.get_entry_mut(4), None);
|
assert_eq!(row.get_entry_mut(4), None);
|
||||||
|
|
||||||
|
let mut inner = row.iter_mut();
|
||||||
|
assert_eq!(Some((0, &mut 3)), inner.next());
|
||||||
|
assert!(inner.next().is_none());
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
|
@ -634,6 +657,11 @@ fn csr_matrix_row_iter() {
|
||||||
assert_eq!(row.get_entry_mut(2), Some(SparseEntryMut::Zero));
|
assert_eq!(row.get_entry_mut(2), Some(SparseEntryMut::Zero));
|
||||||
assert_eq!(row.get_entry_mut(3), Some(SparseEntryMut::NonZero(&mut 5)));
|
assert_eq!(row.get_entry_mut(3), Some(SparseEntryMut::NonZero(&mut 5)));
|
||||||
assert_eq!(row.get_entry_mut(4), None);
|
assert_eq!(row.get_entry_mut(4), None);
|
||||||
|
|
||||||
|
let mut inner = row.iter_mut();
|
||||||
|
assert_eq!(Some((1, &mut 4)), inner.next());
|
||||||
|
assert_eq!(Some((3, &mut 5)), inner.next());
|
||||||
|
assert!(inner.next().is_none());
|
||||||
}
|
}
|
||||||
|
|
||||||
assert!(row_iter.next().is_none());
|
assert!(row_iter.next().is_none());
|
||||||
|
|
Loading…
Reference in New Issue