diff --git a/nalgebra-sparse/src/cs.rs b/nalgebra-sparse/src/cs.rs index c1b4b449..634ce413 100644 --- a/nalgebra-sparse/src/cs.rs +++ b/nalgebra-sparse/src/cs.rs @@ -156,6 +156,40 @@ impl CsMatrix { pub fn lane_iter_mut(&mut self) -> CsLaneIterMut { CsLaneIterMut::new(self.sparsity_pattern.as_ref(), &mut self.values) } + + #[inline] + pub fn filter

(&self, predicate: P) -> Self + where + T: Clone, + P: Fn(usize, usize, &T) -> bool + { + let (major_dim, minor_dim) = (self.pattern().major_dim(), self.pattern().minor_dim()); + let mut new_offsets = Vec::with_capacity(self.pattern().major_dim() + 1); + let mut new_indices = Vec::new(); + let mut new_values = Vec::new(); + + new_offsets.push(0); + for (i, lane) in self.lane_iter().enumerate() { + for (&j, value) in lane.minor_indices().iter().zip(lane.values) { + if predicate(i, j, value) { + new_indices.push(j); + new_values.push(value.clone()); + } + } + + new_offsets.push(new_indices.len()); + } + + // TODO: Avoid checks here + let new_pattern = SparsityPattern::try_from_offsets_and_indices( + major_dim, + minor_dim, + new_offsets, + new_indices) + .expect("Internal error: Sparsity pattern must always be valid."); + + Self::from_pattern_and_values(Arc::new(new_pattern), new_values) + } } impl CsMatrix { diff --git a/nalgebra-sparse/src/csc.rs b/nalgebra-sparse/src/csc.rs index 11b96a6f..1d8b8970 100644 --- a/nalgebra-sparse/src/csc.rs +++ b/nalgebra-sparse/src/csc.rs @@ -324,6 +324,46 @@ impl CscMatrix { pub fn csc_data_mut(&mut self) -> (&[usize], &[usize], &mut [T]) { self.cs.cs_data_mut() } + + /// Creates a sparse matrix that contains only the explicit entries decided by the + /// given predicate. + pub fn filter

(&self, predicate: P) -> Self + where + T: Clone, + P: Fn(usize, usize, &T) -> bool + { + // Note: Predicate uses (row, col, value), so we have to switch around since + // cs uses (major, minor, value) + Self { cs: self.cs.filter(|col_idx, row_idx, v| predicate(row_idx, col_idx, v)) } + } + + /// Returns a new matrix representing the upper triangular part of this matrix. + /// + /// The result includes the diagonal of the matrix. + pub fn upper_triangle(&self) -> Self + where + T: Clone + { + self.filter(|i, j, _| i <= j) + } + + /// Returns a new matrix representing the lower triangular part of this matrix. + /// + /// The result includes the diagonal of the matrix. + pub fn lower_triangle(&self) -> Self + where + T: Clone + { + self.filter(|i, j, _| i >= j) + } + + /// Returns the diagonal of the matrix as a sparse matrix. + pub fn diagonal_as_matrix(&self) -> Self + where + T: Clone + { + self.filter(|i, j, _| i == j) + } } impl CscMatrix @@ -385,6 +425,17 @@ pub struct CscTripletIter<'a, T> { values_iter: Iter<'a, T> } +impl<'a, T: Clone> CscTripletIter<'a, T> { + /// Adapts the triplet iterator to return owned values. + /// + /// The triplet iterator returns references to the values. This method adapts the iterator + /// so that the values are cloned. + #[inline] + pub fn cloned_values(self) -> impl 'a + Iterator { + self.map(|(i, j, v)| (i, j, v.clone())) + } +} + impl<'a, T> Iterator for CscTripletIter<'a, T> { type Item = (usize, usize, &'a T); diff --git a/nalgebra-sparse/src/csr.rs b/nalgebra-sparse/src/csr.rs index ca42492c..1f621c86 100644 --- a/nalgebra-sparse/src/csr.rs +++ b/nalgebra-sparse/src/csr.rs @@ -326,6 +326,44 @@ impl CsrMatrix { pub fn csr_data_mut(&mut self) -> (&[usize], &[usize], &mut [T]) { self.cs.cs_data_mut() } + + /// Creates a sparse matrix that contains only the explicit entries decided by the + /// given predicate. + pub fn filter

(&self, predicate: P) -> Self + where + T: Clone, + P: Fn(usize, usize, &T) -> bool + { + Self { cs: self.cs.filter(|row_idx, col_idx, v| predicate(row_idx, col_idx, v)) } + } + + /// Returns a new matrix representing the upper triangular part of this matrix. + /// + /// The result includes the diagonal of the matrix. + pub fn upper_triangle(&self) -> Self + where + T: Clone + { + self.filter(|i, j, _| i <= j) + } + + /// Returns a new matrix representing the lower triangular part of this matrix. + /// + /// The result includes the diagonal of the matrix. + pub fn lower_triangle(&self) -> Self + where + T: Clone + { + self.filter(|i, j, _| i >= j) + } + + /// Returns the diagonal of the matrix as a sparse matrix. + pub fn diagonal_as_matrix(&self) -> Self + where + T: Clone + { + self.filter(|i, j, _| i == j) + } } impl CsrMatrix @@ -387,6 +425,17 @@ pub struct CsrTripletIter<'a, T> { values_iter: Iter<'a, T> } +impl<'a, T: Clone> CsrTripletIter<'a, T> { + /// Adapts the triplet iterator to return owned values. + /// + /// The triplet iterator returns references to the values. This method adapts the iterator + /// so that the values are cloned. + #[inline] + pub fn cloned_values(self) -> impl 'a + Iterator { + self.map(|(i, j, v)| (i, j, v.clone())) + } +} + impl<'a, T> Iterator for CsrTripletIter<'a, T> { type Item = (usize, usize, &'a T); diff --git a/nalgebra-sparse/tests/unit_tests/csc.rs b/nalgebra-sparse/tests/unit_tests/csc.rs index a16e1686..4faa4e12 100644 --- a/nalgebra-sparse/tests/unit_tests/csc.rs +++ b/nalgebra-sparse/tests/unit_tests/csc.rs @@ -3,9 +3,12 @@ use nalgebra_sparse::SparseFormatErrorKind; use nalgebra::DMatrix; use proptest::prelude::*; +use proptest::sample::subsequence; use crate::common::csc_strategy; +use std::collections::HashSet; + #[test] fn csc_matrix_valid_data() { // Construct matrix from valid data and check that selected methods return results @@ -271,4 +274,58 @@ proptest! { prop_assert_eq!(dense_transpose, DMatrix::from(&csc_transpose)); prop_assert_eq!(csc.nnz(), csc_transpose.nnz()); } + + #[test] + fn csc_filter( + (csc, triplet_subset) + in csc_strategy() + .prop_flat_map(|matrix| { + let triplets: Vec<_> = matrix.triplet_iter().cloned_values().collect(); + let subset = subsequence(triplets, 0 ..= matrix.nnz()) + .prop_map(|triplet_subset| { + let set: HashSet<_> = triplet_subset.into_iter().collect(); + set + }); + (Just(matrix), subset) + })) + { + // We generate a CscMatrix and a HashSet corresponding to a subset of the (i, j, v) + // values in the matrix, which we use for filtering the matrix entries. + // The resulting triplets in the filtered matrix must then be exactly equal to + // the subset. + let filtered = csc.filter(|i, j, v| triplet_subset.contains(&(i, j, *v))); + let filtered_triplets: HashSet<_> = filtered + .triplet_iter() + .cloned_values() + .collect(); + + prop_assert_eq!(filtered_triplets, triplet_subset); + } + + #[test] + fn csc_lower_triangle_agrees_with_dense(csc in csc_strategy()) { + let csc_lower_triangle = csc.lower_triangle(); + prop_assert_eq!(DMatrix::from(&csc_lower_triangle), DMatrix::from(&csc).lower_triangle()); + prop_assert!(csc_lower_triangle.nnz() <= csc.nnz()); + } + + #[test] + fn csc_upper_triangle_agrees_with_dense(csc in csc_strategy()) { + let csc_upper_triangle = csc.upper_triangle(); + prop_assert_eq!(DMatrix::from(&csc_upper_triangle), DMatrix::from(&csc).upper_triangle()); + prop_assert!(csc_upper_triangle.nnz() <= csc.nnz()); + } + + #[test] + fn csc_diagonal_as_matrix(csc in csc_strategy()) { + let d = csc.diagonal_as_matrix(); + let d_entries: HashSet<_> = d.triplet_iter().cloned_values().collect(); + let csc_diagonal_entries: HashSet<_> = csc + .triplet_iter() + .cloned_values() + .filter(|&(i, j, _)| i == j) + .collect(); + + prop_assert_eq!(d_entries, csc_diagonal_entries); + } } \ No newline at end of file diff --git a/nalgebra-sparse/tests/unit_tests/csr.rs b/nalgebra-sparse/tests/unit_tests/csr.rs index 424bc2c1..4885d25b 100644 --- a/nalgebra-sparse/tests/unit_tests/csr.rs +++ b/nalgebra-sparse/tests/unit_tests/csr.rs @@ -1,9 +1,15 @@ use nalgebra_sparse::csr::CsrMatrix; use nalgebra_sparse::SparseFormatErrorKind; use nalgebra::DMatrix; + use proptest::prelude::*; +use proptest::sample::subsequence; + use crate::common::csr_strategy; +use std::collections::HashSet; + + #[test] fn csr_matrix_valid_data() { // Construct matrix from valid data and check that selected methods return results @@ -269,4 +275,58 @@ proptest! { prop_assert_eq!(dense_transpose, DMatrix::from(&csr_transpose)); prop_assert_eq!(csr.nnz(), csr_transpose.nnz()); } + + #[test] + fn csr_filter( + (csr, triplet_subset) + in csr_strategy() + .prop_flat_map(|matrix| { + let triplets: Vec<_> = matrix.triplet_iter().cloned_values().collect(); + let subset = subsequence(triplets, 0 ..= matrix.nnz()) + .prop_map(|triplet_subset| { + let set: HashSet<_> = triplet_subset.into_iter().collect(); + set + }); + (Just(matrix), subset) + })) + { + // We generate a CsrMatrix and a HashSet corresponding to a subset of the (i, j, v) + // values in the matrix, which we use for filtering the matrix entries. + // The resulting triplets in the filtered matrix must then be exactly equal to + // the subset. + let filtered = csr.filter(|i, j, v| triplet_subset.contains(&(i, j, *v))); + let filtered_triplets: HashSet<_> = filtered + .triplet_iter() + .cloned_values() + .collect(); + + prop_assert_eq!(filtered_triplets, triplet_subset); + } + + #[test] + fn csr_lower_triangle_agrees_with_dense(csr in csr_strategy()) { + let csr_lower_triangle = csr.lower_triangle(); + prop_assert_eq!(DMatrix::from(&csr_lower_triangle), DMatrix::from(&csr).lower_triangle()); + prop_assert!(csr_lower_triangle.nnz() <= csr.nnz()); + } + + #[test] + fn csr_upper_triangle_agrees_with_dense(csr in csr_strategy()) { + let csr_upper_triangle = csr.upper_triangle(); + prop_assert_eq!(DMatrix::from(&csr_upper_triangle), DMatrix::from(&csr).upper_triangle()); + prop_assert!(csr_upper_triangle.nnz() <= csr.nnz()); + } + + #[test] + fn csr_diagonal_as_matrix(csr in csr_strategy()) { + let d = csr.diagonal_as_matrix(); + let d_entries: HashSet<_> = d.triplet_iter().cloned_values().collect(); + let csr_diagonal_entries: HashSet<_> = csr + .triplet_iter() + .cloned_values() + .filter(|&(i, j, _)| i == j) + .collect(); + + prop_assert_eq!(d_entries, csr_diagonal_entries); + } } \ No newline at end of file