From 4788dd19c98f26f709293e87788ec7b37d61cac1 Mon Sep 17 00:00:00 2001 From: Austen Nelson Date: Tue, 3 May 2022 13:26:56 -0700 Subject: [PATCH] adds row and column iterators for sparse matrices --- nalgebra-sparse/src/cs.rs | 62 +++++++++++++++++++++++++ nalgebra-sparse/src/csc.rs | 20 +++++++- nalgebra-sparse/src/csr.rs | 20 +++++++- nalgebra-sparse/tests/unit_tests/csc.rs | 28 +++++++++++ nalgebra-sparse/tests/unit_tests/csr.rs | 28 +++++++++++ 5 files changed, 156 insertions(+), 2 deletions(-) diff --git a/nalgebra-sparse/src/cs.rs b/nalgebra-sparse/src/cs.rs index 474eb2c0..9d0eeb5c 100644 --- a/nalgebra-sparse/src/cs.rs +++ b/nalgebra-sparse/src/cs.rs @@ -1,5 +1,7 @@ +use std::iter::Zip; use std::mem::replace; use std::ops::Range; +use std::slice::{Iter, IterMut}; use num_traits::One; @@ -272,6 +274,54 @@ fn get_mut_entry_from_slices<'a, T>( } } +pub struct CsInnerIter<'a, T> { + inner_iter: Zip, Iter<'a, T>>, +} + +pub struct CsInnerIterMut<'a, T> { + inner_iter_mut: Zip, IterMut<'a, T>>, +} + +impl<'a, T> CsInnerIter<'a, T> { + pub fn new(indices: &'a [usize], values: &'a [T]) -> Self { + Self { + inner_iter: indices.iter().zip(values.iter()), + } + } +} + +impl<'a, T> CsInnerIterMut<'a, T> { + pub fn new(indices: &'a [usize], values: &'a mut [T]) -> Self { + Self { + inner_iter_mut: indices.iter().zip(values.iter_mut()), + } + } +} + +impl<'a, T> Iterator for CsInnerIter<'a, T> +where + T: 'a, +{ + type Item = (usize, &'a T); + + fn next(&mut self) -> Option { + self.inner_iter.next().map(|(index, value)| (*index, value)) + } +} + +impl<'a, T> Iterator for CsInnerIterMut<'a, T> +where + T: 'a, +{ + type Item = (usize, &'a mut T); + + fn next(&mut self) -> Option { + self.inner_iter_mut + .next() + .map(|(index, value)| (*index, value)) + } +} + #[derive(Debug, Clone, PartialEq, Eq)] pub struct CsLane<'a, T> { minor_dim: usize, @@ -415,6 +465,12 @@ macro_rules! impl_cs_lane_common_methods { global_col_index, ) } + + #[inline] + #[must_use] + pub fn iter(&self) -> CsInnerIter<'_, T> { + CsInnerIter::new(self.minor_indices, self.values) + } } }; } @@ -440,6 +496,12 @@ impl<'a, T> CsLaneMut<'a, T> { global_minor_index, ) } + + #[inline] + #[must_use] + pub fn iter_mut(&mut self) -> CsInnerIterMut<'_, T> { + CsInnerIterMut::new(self.minor_indices, self.values) + } } /// Helper struct for working with uninitialized data in vectors. diff --git a/nalgebra-sparse/src/csc.rs b/nalgebra-sparse/src/csc.rs index d926dafb..3bcfddcb 100644 --- a/nalgebra-sparse/src/csc.rs +++ b/nalgebra-sparse/src/csc.rs @@ -7,7 +7,9 @@ mod csc_serde; use crate::cs; -use crate::cs::{CsLane, CsLaneIter, CsLaneIterMut, CsLaneMut, CsMatrix}; +use crate::cs::{ + CsInnerIter, CsInnerIterMut, CsLane, CsLaneIter, CsLaneIterMut, CsLaneMut, CsMatrix, +}; use crate::csr::CsrMatrix; use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter}; use crate::{SparseEntry, SparseEntryMut, SparseFormatError, SparseFormatErrorKind}; @@ -717,6 +719,14 @@ macro_rules! impl_csc_col_common_methods { pub fn get_entry(&self, global_row_index: usize) -> Option> { self.lane.get_entry(global_row_index) } + + /// Iterator over the row indices and elements of a column of a CSC matrix. + /// Equivalent to `col.row_indices().iter().zip(col.values().iter())`. + #[inline] + #[must_use] + pub fn iter(&self) -> CsInnerIter<'_, T> { + self.lane.iter() + } } }; } @@ -744,6 +754,14 @@ impl<'a, T> CscColMut<'a, T> { pub fn get_entry_mut(&mut self, global_row_index: usize) -> Option> { self.lane.get_entry_mut(global_row_index) } + + /// Iterator over the row indices and mutable elements of a column of a CSC matrix. + /// Equivalent to `col.row_indices().iter().zip(col.values_mut().iter_mut())`. + #[inline] + #[must_use] + pub fn iter_mut(&mut self) -> CsInnerIterMut<'_, T> { + self.lane.iter_mut() + } } /// Column iterator for [CscMatrix](struct.CscMatrix.html). diff --git a/nalgebra-sparse/src/csr.rs b/nalgebra-sparse/src/csr.rs index 90be35f1..4d9c147c 100644 --- a/nalgebra-sparse/src/csr.rs +++ b/nalgebra-sparse/src/csr.rs @@ -7,7 +7,9 @@ mod csr_serde; use crate::cs; -use crate::cs::{CsLane, CsLaneIter, CsLaneIterMut, CsLaneMut, CsMatrix}; +use crate::cs::{ + CsInnerIter, CsInnerIterMut, CsLane, CsLaneIter, CsLaneIterMut, CsLaneMut, CsMatrix, +}; use crate::csc::CscMatrix; use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter}; use crate::{SparseEntry, SparseEntryMut, SparseFormatError, SparseFormatErrorKind}; @@ -719,6 +721,14 @@ macro_rules! impl_csr_row_common_methods { pub fn get_entry(&self, global_col_index: usize) -> Option> { self.lane.get_entry(global_col_index) } + + /// Iterator over the column indices and elements of a row of a CSR matrix. + /// Equivalent to `row.col_indices().iter().zip(row.values().iter())`. + #[inline] + #[must_use] + pub fn iter(&self) -> CsInnerIter<'_, T> { + self.lane.iter() + } } }; } @@ -749,6 +759,14 @@ impl<'a, T> CsrRowMut<'a, T> { pub fn get_entry_mut(&mut self, global_col_index: usize) -> Option> { self.lane.get_entry_mut(global_col_index) } + + /// Iterator over the column indices and mutable elements of a row of a CSR matrix. + /// Equivalent to `row.col_indices().iter().zip(row.values_mut().iter_mut())`. + #[inline] + #[must_use] + pub fn iter_mut(&mut self) -> CsInnerIterMut<'_, T> { + self.lane.iter_mut() + } } /// Row iterator for [CsrMatrix](struct.CsrMatrix.html). diff --git a/nalgebra-sparse/tests/unit_tests/csc.rs b/nalgebra-sparse/tests/unit_tests/csc.rs index 1554b8a6..9b968d8e 100644 --- a/nalgebra-sparse/tests/unit_tests/csc.rs +++ b/nalgebra-sparse/tests/unit_tests/csc.rs @@ -537,6 +537,11 @@ fn csc_matrix_col_iter() { assert_eq!(col.get_entry(2), Some(SparseEntry::NonZero(&2))); assert_eq!(col.get_entry(3), Some(SparseEntry::Zero)); assert_eq!(col.get_entry(4), None); + + let mut inner = col.iter(); + assert_eq!(Some((1, &1)), inner.next()); + assert_eq!(Some((2, &2)), inner.next()); + assert!(inner.next().is_none()); } { @@ -550,6 +555,10 @@ fn csc_matrix_col_iter() { assert_eq!(col.get_entry(2), Some(SparseEntry::Zero)); assert_eq!(col.get_entry(3), Some(SparseEntry::Zero)); assert_eq!(col.get_entry(4), None); + + let mut inner = col.iter(); + assert_eq!(Some((0, &3)), inner.next()); + assert!(inner.next().is_none()); } { @@ -563,6 +572,11 @@ fn csc_matrix_col_iter() { assert_eq!(col.get_entry(2), Some(SparseEntry::Zero)); assert_eq!(col.get_entry(3), Some(SparseEntry::NonZero(&5))); assert_eq!(col.get_entry(4), None); + + let mut inner = col.iter(); + assert_eq!(Some((1, &4)), inner.next()); + assert_eq!(Some((3, &5)), inner.next()); + assert!(inner.next().is_none()); } assert!(col_iter.next().is_none()); @@ -595,6 +609,11 @@ fn csc_matrix_col_iter() { assert_eq!(col.get_entry_mut(2), Some(SparseEntryMut::NonZero(&mut 2))); assert_eq!(col.get_entry_mut(3), Some(SparseEntryMut::Zero)); assert_eq!(col.get_entry_mut(4), None); + + let mut inner = col.iter_mut(); + assert_eq!(Some((1, &mut 1)), inner.next()); + assert_eq!(Some((2, &mut 2)), inner.next()); + assert!(inner.next().is_none()); } { @@ -616,6 +635,10 @@ fn csc_matrix_col_iter() { assert_eq!(col.get_entry_mut(2), Some(SparseEntryMut::Zero)); assert_eq!(col.get_entry_mut(3), Some(SparseEntryMut::Zero)); assert_eq!(col.get_entry_mut(4), None); + + let mut inner = col.iter_mut(); + assert_eq!(Some((0, &mut 3)), inner.next()); + assert!(inner.next().is_none()); } { @@ -640,6 +663,11 @@ fn csc_matrix_col_iter() { assert_eq!(col.get_entry_mut(2), Some(SparseEntryMut::Zero)); assert_eq!(col.get_entry_mut(3), Some(SparseEntryMut::NonZero(&mut 5))); assert_eq!(col.get_entry_mut(4), None); + + let mut inner = col.iter_mut(); + assert_eq!(Some((1, &mut 4)), inner.next()); + assert_eq!(Some((3, &mut 5)), inner.next()); + assert!(inner.next().is_none()); } assert!(col_iter.next().is_none()); diff --git a/nalgebra-sparse/tests/unit_tests/csr.rs b/nalgebra-sparse/tests/unit_tests/csr.rs index a00470d5..c6d0d023 100644 --- a/nalgebra-sparse/tests/unit_tests/csr.rs +++ b/nalgebra-sparse/tests/unit_tests/csr.rs @@ -531,6 +531,11 @@ fn csr_matrix_row_iter() { assert_eq!(row.get_entry(2), Some(SparseEntry::NonZero(&2))); assert_eq!(row.get_entry(3), Some(SparseEntry::Zero)); assert_eq!(row.get_entry(4), None); + + let mut inner = row.iter(); + assert_eq!(Some((1, &1)), inner.next()); + assert_eq!(Some((2, &2)), inner.next()); + assert!(inner.next().is_none()); } { @@ -544,6 +549,10 @@ fn csr_matrix_row_iter() { assert_eq!(row.get_entry(2), Some(SparseEntry::Zero)); assert_eq!(row.get_entry(3), Some(SparseEntry::Zero)); assert_eq!(row.get_entry(4), None); + + let mut inner = row.iter(); + assert_eq!(Some((0, &3)), inner.next()); + assert!(inner.next().is_none()); } { @@ -557,6 +566,11 @@ fn csr_matrix_row_iter() { assert_eq!(row.get_entry(2), Some(SparseEntry::Zero)); assert_eq!(row.get_entry(3), Some(SparseEntry::NonZero(&5))); assert_eq!(row.get_entry(4), None); + + let mut inner = row.iter(); + assert_eq!(Some((1, &4)), inner.next()); + assert_eq!(Some((3, &5)), inner.next()); + assert!(inner.next().is_none()); } assert!(row_iter.next().is_none()); @@ -589,6 +603,11 @@ fn csr_matrix_row_iter() { assert_eq!(row.get_entry_mut(2), Some(SparseEntryMut::NonZero(&mut 2))); assert_eq!(row.get_entry_mut(3), Some(SparseEntryMut::Zero)); assert_eq!(row.get_entry_mut(4), None); + + let mut inner = row.iter_mut(); + assert_eq!(Some((1, &mut 1)), inner.next()); + assert_eq!(Some((2, &mut 2)), inner.next()); + assert!(inner.next().is_none()); } { @@ -610,6 +629,10 @@ fn csr_matrix_row_iter() { assert_eq!(row.get_entry_mut(2), Some(SparseEntryMut::Zero)); assert_eq!(row.get_entry_mut(3), Some(SparseEntryMut::Zero)); assert_eq!(row.get_entry_mut(4), None); + + let mut inner = row.iter_mut(); + assert_eq!(Some((0, &mut 3)), inner.next()); + assert!(inner.next().is_none()); } { @@ -634,6 +657,11 @@ fn csr_matrix_row_iter() { assert_eq!(row.get_entry_mut(2), Some(SparseEntryMut::Zero)); assert_eq!(row.get_entry_mut(3), Some(SparseEntryMut::NonZero(&mut 5))); assert_eq!(row.get_entry_mut(4), None); + + let mut inner = row.iter_mut(); + assert_eq!(Some((1, &mut 4)), inner.next()); + assert_eq!(Some((3, &mut 5)), inner.next()); + assert!(inner.next().is_none()); } assert!(row_iter.next().is_none());