Implement CsrMatrix/CscMatrix::filter and associated helpers

Includes ::lower_triangle(), ::upper_triangle() and
::diagonal_matrix().
This commit is contained in:
Andreas Longva 2021-01-14 17:12:08 +01:00
parent 3453577a16
commit 5869f784e5
5 changed files with 251 additions and 0 deletions

View File

@ -156,6 +156,40 @@ impl<T> CsMatrix<T> {
pub fn lane_iter_mut(&mut self) -> CsLaneIterMut<T> { pub fn lane_iter_mut(&mut self) -> CsLaneIterMut<T> {
CsLaneIterMut::new(self.sparsity_pattern.as_ref(), &mut self.values) CsLaneIterMut::new(self.sparsity_pattern.as_ref(), &mut self.values)
} }
#[inline]
pub fn filter<P>(&self, predicate: P) -> Self
where
T: Clone,
P: Fn(usize, usize, &T) -> bool
{
let (major_dim, minor_dim) = (self.pattern().major_dim(), self.pattern().minor_dim());
let mut new_offsets = Vec::with_capacity(self.pattern().major_dim() + 1);
let mut new_indices = Vec::new();
let mut new_values = Vec::new();
new_offsets.push(0);
for (i, lane) in self.lane_iter().enumerate() {
for (&j, value) in lane.minor_indices().iter().zip(lane.values) {
if predicate(i, j, value) {
new_indices.push(j);
new_values.push(value.clone());
}
}
new_offsets.push(new_indices.len());
}
// TODO: Avoid checks here
let new_pattern = SparsityPattern::try_from_offsets_and_indices(
major_dim,
minor_dim,
new_offsets,
new_indices)
.expect("Internal error: Sparsity pattern must always be valid.");
Self::from_pattern_and_values(Arc::new(new_pattern), new_values)
}
} }
impl<T: Scalar + One> CsMatrix<T> { impl<T: Scalar + One> CsMatrix<T> {

View File

@ -324,6 +324,46 @@ impl<T> CscMatrix<T> {
pub fn csc_data_mut(&mut self) -> (&[usize], &[usize], &mut [T]) { pub fn csc_data_mut(&mut self) -> (&[usize], &[usize], &mut [T]) {
self.cs.cs_data_mut() self.cs.cs_data_mut()
} }
/// Creates a sparse matrix that contains only the explicit entries decided by the
/// given predicate.
pub fn filter<P>(&self, predicate: P) -> Self
where
T: Clone,
P: Fn(usize, usize, &T) -> bool
{
// Note: Predicate uses (row, col, value), so we have to switch around since
// cs uses (major, minor, value)
Self { cs: self.cs.filter(|col_idx, row_idx, v| predicate(row_idx, col_idx, v)) }
}
/// Returns a new matrix representing the upper triangular part of this matrix.
///
/// The result includes the diagonal of the matrix.
pub fn upper_triangle(&self) -> Self
where
T: Clone
{
self.filter(|i, j, _| i <= j)
}
/// Returns a new matrix representing the lower triangular part of this matrix.
///
/// The result includes the diagonal of the matrix.
pub fn lower_triangle(&self) -> Self
where
T: Clone
{
self.filter(|i, j, _| i >= j)
}
/// Returns the diagonal of the matrix as a sparse matrix.
pub fn diagonal_as_matrix(&self) -> Self
where
T: Clone
{
self.filter(|i, j, _| i == j)
}
} }
impl<T> CscMatrix<T> impl<T> CscMatrix<T>
@ -385,6 +425,17 @@ pub struct CscTripletIter<'a, T> {
values_iter: Iter<'a, T> values_iter: Iter<'a, T>
} }
impl<'a, T: Clone> CscTripletIter<'a, T> {
/// Adapts the triplet iterator to return owned values.
///
/// The triplet iterator returns references to the values. This method adapts the iterator
/// so that the values are cloned.
#[inline]
pub fn cloned_values(self) -> impl 'a + Iterator<Item=(usize, usize, T)> {
self.map(|(i, j, v)| (i, j, v.clone()))
}
}
impl<'a, T> Iterator for CscTripletIter<'a, T> { impl<'a, T> Iterator for CscTripletIter<'a, T> {
type Item = (usize, usize, &'a T); type Item = (usize, usize, &'a T);

View File

@ -326,6 +326,44 @@ impl<T> CsrMatrix<T> {
pub fn csr_data_mut(&mut self) -> (&[usize], &[usize], &mut [T]) { pub fn csr_data_mut(&mut self) -> (&[usize], &[usize], &mut [T]) {
self.cs.cs_data_mut() self.cs.cs_data_mut()
} }
/// Creates a sparse matrix that contains only the explicit entries decided by the
/// given predicate.
pub fn filter<P>(&self, predicate: P) -> Self
where
T: Clone,
P: Fn(usize, usize, &T) -> bool
{
Self { cs: self.cs.filter(|row_idx, col_idx, v| predicate(row_idx, col_idx, v)) }
}
/// Returns a new matrix representing the upper triangular part of this matrix.
///
/// The result includes the diagonal of the matrix.
pub fn upper_triangle(&self) -> Self
where
T: Clone
{
self.filter(|i, j, _| i <= j)
}
/// Returns a new matrix representing the lower triangular part of this matrix.
///
/// The result includes the diagonal of the matrix.
pub fn lower_triangle(&self) -> Self
where
T: Clone
{
self.filter(|i, j, _| i >= j)
}
/// Returns the diagonal of the matrix as a sparse matrix.
pub fn diagonal_as_matrix(&self) -> Self
where
T: Clone
{
self.filter(|i, j, _| i == j)
}
} }
impl<T> CsrMatrix<T> impl<T> CsrMatrix<T>
@ -387,6 +425,17 @@ pub struct CsrTripletIter<'a, T> {
values_iter: Iter<'a, T> values_iter: Iter<'a, T>
} }
impl<'a, T: Clone> CsrTripletIter<'a, T> {
/// Adapts the triplet iterator to return owned values.
///
/// The triplet iterator returns references to the values. This method adapts the iterator
/// so that the values are cloned.
#[inline]
pub fn cloned_values(self) -> impl 'a + Iterator<Item=(usize, usize, T)> {
self.map(|(i, j, v)| (i, j, v.clone()))
}
}
impl<'a, T> Iterator for CsrTripletIter<'a, T> { impl<'a, T> Iterator for CsrTripletIter<'a, T> {
type Item = (usize, usize, &'a T); type Item = (usize, usize, &'a T);

View File

@ -3,9 +3,12 @@ use nalgebra_sparse::SparseFormatErrorKind;
use nalgebra::DMatrix; use nalgebra::DMatrix;
use proptest::prelude::*; use proptest::prelude::*;
use proptest::sample::subsequence;
use crate::common::csc_strategy; use crate::common::csc_strategy;
use std::collections::HashSet;
#[test] #[test]
fn csc_matrix_valid_data() { fn csc_matrix_valid_data() {
// Construct matrix from valid data and check that selected methods return results // Construct matrix from valid data and check that selected methods return results
@ -271,4 +274,58 @@ proptest! {
prop_assert_eq!(dense_transpose, DMatrix::from(&csc_transpose)); prop_assert_eq!(dense_transpose, DMatrix::from(&csc_transpose));
prop_assert_eq!(csc.nnz(), csc_transpose.nnz()); prop_assert_eq!(csc.nnz(), csc_transpose.nnz());
} }
#[test]
fn csc_filter(
(csc, triplet_subset)
in csc_strategy()
.prop_flat_map(|matrix| {
let triplets: Vec<_> = matrix.triplet_iter().cloned_values().collect();
let subset = subsequence(triplets, 0 ..= matrix.nnz())
.prop_map(|triplet_subset| {
let set: HashSet<_> = triplet_subset.into_iter().collect();
set
});
(Just(matrix), subset)
}))
{
// We generate a CscMatrix and a HashSet corresponding to a subset of the (i, j, v)
// values in the matrix, which we use for filtering the matrix entries.
// The resulting triplets in the filtered matrix must then be exactly equal to
// the subset.
let filtered = csc.filter(|i, j, v| triplet_subset.contains(&(i, j, *v)));
let filtered_triplets: HashSet<_> = filtered
.triplet_iter()
.cloned_values()
.collect();
prop_assert_eq!(filtered_triplets, triplet_subset);
}
#[test]
fn csc_lower_triangle_agrees_with_dense(csc in csc_strategy()) {
let csc_lower_triangle = csc.lower_triangle();
prop_assert_eq!(DMatrix::from(&csc_lower_triangle), DMatrix::from(&csc).lower_triangle());
prop_assert!(csc_lower_triangle.nnz() <= csc.nnz());
}
#[test]
fn csc_upper_triangle_agrees_with_dense(csc in csc_strategy()) {
let csc_upper_triangle = csc.upper_triangle();
prop_assert_eq!(DMatrix::from(&csc_upper_triangle), DMatrix::from(&csc).upper_triangle());
prop_assert!(csc_upper_triangle.nnz() <= csc.nnz());
}
#[test]
fn csc_diagonal_as_matrix(csc in csc_strategy()) {
let d = csc.diagonal_as_matrix();
let d_entries: HashSet<_> = d.triplet_iter().cloned_values().collect();
let csc_diagonal_entries: HashSet<_> = csc
.triplet_iter()
.cloned_values()
.filter(|&(i, j, _)| i == j)
.collect();
prop_assert_eq!(d_entries, csc_diagonal_entries);
}
} }

View File

@ -1,9 +1,15 @@
use nalgebra_sparse::csr::CsrMatrix; use nalgebra_sparse::csr::CsrMatrix;
use nalgebra_sparse::SparseFormatErrorKind; use nalgebra_sparse::SparseFormatErrorKind;
use nalgebra::DMatrix; use nalgebra::DMatrix;
use proptest::prelude::*; use proptest::prelude::*;
use proptest::sample::subsequence;
use crate::common::csr_strategy; use crate::common::csr_strategy;
use std::collections::HashSet;
#[test] #[test]
fn csr_matrix_valid_data() { fn csr_matrix_valid_data() {
// Construct matrix from valid data and check that selected methods return results // Construct matrix from valid data and check that selected methods return results
@ -269,4 +275,58 @@ proptest! {
prop_assert_eq!(dense_transpose, DMatrix::from(&csr_transpose)); prop_assert_eq!(dense_transpose, DMatrix::from(&csr_transpose));
prop_assert_eq!(csr.nnz(), csr_transpose.nnz()); prop_assert_eq!(csr.nnz(), csr_transpose.nnz());
} }
#[test]
fn csr_filter(
(csr, triplet_subset)
in csr_strategy()
.prop_flat_map(|matrix| {
let triplets: Vec<_> = matrix.triplet_iter().cloned_values().collect();
let subset = subsequence(triplets, 0 ..= matrix.nnz())
.prop_map(|triplet_subset| {
let set: HashSet<_> = triplet_subset.into_iter().collect();
set
});
(Just(matrix), subset)
}))
{
// We generate a CsrMatrix and a HashSet corresponding to a subset of the (i, j, v)
// values in the matrix, which we use for filtering the matrix entries.
// The resulting triplets in the filtered matrix must then be exactly equal to
// the subset.
let filtered = csr.filter(|i, j, v| triplet_subset.contains(&(i, j, *v)));
let filtered_triplets: HashSet<_> = filtered
.triplet_iter()
.cloned_values()
.collect();
prop_assert_eq!(filtered_triplets, triplet_subset);
}
#[test]
fn csr_lower_triangle_agrees_with_dense(csr in csr_strategy()) {
let csr_lower_triangle = csr.lower_triangle();
prop_assert_eq!(DMatrix::from(&csr_lower_triangle), DMatrix::from(&csr).lower_triangle());
prop_assert!(csr_lower_triangle.nnz() <= csr.nnz());
}
#[test]
fn csr_upper_triangle_agrees_with_dense(csr in csr_strategy()) {
let csr_upper_triangle = csr.upper_triangle();
prop_assert_eq!(DMatrix::from(&csr_upper_triangle), DMatrix::from(&csr).upper_triangle());
prop_assert!(csr_upper_triangle.nnz() <= csr.nnz());
}
#[test]
fn csr_diagonal_as_matrix(csr in csr_strategy()) {
let d = csr.diagonal_as_matrix();
let d_entries: HashSet<_> = d.triplet_iter().cloned_values().collect();
let csr_diagonal_entries: HashSet<_> = csr
.triplet_iter()
.cloned_values()
.filter(|&(i, j, _)| i == j)
.collect();
prop_assert_eq!(d_entries, csr_diagonal_entries);
}
} }