Implement CsrMatrix/CscMatrix::filter and associated helpers
Includes ::lower_triangle(), ::upper_triangle() and ::diagonal_matrix().
This commit is contained in:
parent
3453577a16
commit
5869f784e5
|
@ -156,6 +156,40 @@ impl<T> CsMatrix<T> {
|
|||
pub fn lane_iter_mut(&mut self) -> CsLaneIterMut<T> {
|
||||
CsLaneIterMut::new(self.sparsity_pattern.as_ref(), &mut self.values)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn filter<P>(&self, predicate: P) -> Self
|
||||
where
|
||||
T: Clone,
|
||||
P: Fn(usize, usize, &T) -> bool
|
||||
{
|
||||
let (major_dim, minor_dim) = (self.pattern().major_dim(), self.pattern().minor_dim());
|
||||
let mut new_offsets = Vec::with_capacity(self.pattern().major_dim() + 1);
|
||||
let mut new_indices = Vec::new();
|
||||
let mut new_values = Vec::new();
|
||||
|
||||
new_offsets.push(0);
|
||||
for (i, lane) in self.lane_iter().enumerate() {
|
||||
for (&j, value) in lane.minor_indices().iter().zip(lane.values) {
|
||||
if predicate(i, j, value) {
|
||||
new_indices.push(j);
|
||||
new_values.push(value.clone());
|
||||
}
|
||||
}
|
||||
|
||||
new_offsets.push(new_indices.len());
|
||||
}
|
||||
|
||||
// TODO: Avoid checks here
|
||||
let new_pattern = SparsityPattern::try_from_offsets_and_indices(
|
||||
major_dim,
|
||||
minor_dim,
|
||||
new_offsets,
|
||||
new_indices)
|
||||
.expect("Internal error: Sparsity pattern must always be valid.");
|
||||
|
||||
Self::from_pattern_and_values(Arc::new(new_pattern), new_values)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Scalar + One> CsMatrix<T> {
|
||||
|
|
|
@ -324,6 +324,46 @@ impl<T> CscMatrix<T> {
|
|||
pub fn csc_data_mut(&mut self) -> (&[usize], &[usize], &mut [T]) {
|
||||
self.cs.cs_data_mut()
|
||||
}
|
||||
|
||||
/// Creates a sparse matrix that contains only the explicit entries decided by the
|
||||
/// given predicate.
|
||||
pub fn filter<P>(&self, predicate: P) -> Self
|
||||
where
|
||||
T: Clone,
|
||||
P: Fn(usize, usize, &T) -> bool
|
||||
{
|
||||
// Note: Predicate uses (row, col, value), so we have to switch around since
|
||||
// cs uses (major, minor, value)
|
||||
Self { cs: self.cs.filter(|col_idx, row_idx, v| predicate(row_idx, col_idx, v)) }
|
||||
}
|
||||
|
||||
/// Returns a new matrix representing the upper triangular part of this matrix.
|
||||
///
|
||||
/// The result includes the diagonal of the matrix.
|
||||
pub fn upper_triangle(&self) -> Self
|
||||
where
|
||||
T: Clone
|
||||
{
|
||||
self.filter(|i, j, _| i <= j)
|
||||
}
|
||||
|
||||
/// Returns a new matrix representing the lower triangular part of this matrix.
|
||||
///
|
||||
/// The result includes the diagonal of the matrix.
|
||||
pub fn lower_triangle(&self) -> Self
|
||||
where
|
||||
T: Clone
|
||||
{
|
||||
self.filter(|i, j, _| i >= j)
|
||||
}
|
||||
|
||||
/// Returns the diagonal of the matrix as a sparse matrix.
|
||||
pub fn diagonal_as_matrix(&self) -> Self
|
||||
where
|
||||
T: Clone
|
||||
{
|
||||
self.filter(|i, j, _| i == j)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> CscMatrix<T>
|
||||
|
@ -385,6 +425,17 @@ pub struct CscTripletIter<'a, T> {
|
|||
values_iter: Iter<'a, T>
|
||||
}
|
||||
|
||||
impl<'a, T: Clone> CscTripletIter<'a, T> {
|
||||
/// Adapts the triplet iterator to return owned values.
|
||||
///
|
||||
/// The triplet iterator returns references to the values. This method adapts the iterator
|
||||
/// so that the values are cloned.
|
||||
#[inline]
|
||||
pub fn cloned_values(self) -> impl 'a + Iterator<Item=(usize, usize, T)> {
|
||||
self.map(|(i, j, v)| (i, j, v.clone()))
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T> Iterator for CscTripletIter<'a, T> {
|
||||
type Item = (usize, usize, &'a T);
|
||||
|
||||
|
|
|
@ -326,6 +326,44 @@ impl<T> CsrMatrix<T> {
|
|||
pub fn csr_data_mut(&mut self) -> (&[usize], &[usize], &mut [T]) {
|
||||
self.cs.cs_data_mut()
|
||||
}
|
||||
|
||||
/// Creates a sparse matrix that contains only the explicit entries decided by the
|
||||
/// given predicate.
|
||||
pub fn filter<P>(&self, predicate: P) -> Self
|
||||
where
|
||||
T: Clone,
|
||||
P: Fn(usize, usize, &T) -> bool
|
||||
{
|
||||
Self { cs: self.cs.filter(|row_idx, col_idx, v| predicate(row_idx, col_idx, v)) }
|
||||
}
|
||||
|
||||
/// Returns a new matrix representing the upper triangular part of this matrix.
|
||||
///
|
||||
/// The result includes the diagonal of the matrix.
|
||||
pub fn upper_triangle(&self) -> Self
|
||||
where
|
||||
T: Clone
|
||||
{
|
||||
self.filter(|i, j, _| i <= j)
|
||||
}
|
||||
|
||||
/// Returns a new matrix representing the lower triangular part of this matrix.
|
||||
///
|
||||
/// The result includes the diagonal of the matrix.
|
||||
pub fn lower_triangle(&self) -> Self
|
||||
where
|
||||
T: Clone
|
||||
{
|
||||
self.filter(|i, j, _| i >= j)
|
||||
}
|
||||
|
||||
/// Returns the diagonal of the matrix as a sparse matrix.
|
||||
pub fn diagonal_as_matrix(&self) -> Self
|
||||
where
|
||||
T: Clone
|
||||
{
|
||||
self.filter(|i, j, _| i == j)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> CsrMatrix<T>
|
||||
|
@ -387,6 +425,17 @@ pub struct CsrTripletIter<'a, T> {
|
|||
values_iter: Iter<'a, T>
|
||||
}
|
||||
|
||||
impl<'a, T: Clone> CsrTripletIter<'a, T> {
|
||||
/// Adapts the triplet iterator to return owned values.
|
||||
///
|
||||
/// The triplet iterator returns references to the values. This method adapts the iterator
|
||||
/// so that the values are cloned.
|
||||
#[inline]
|
||||
pub fn cloned_values(self) -> impl 'a + Iterator<Item=(usize, usize, T)> {
|
||||
self.map(|(i, j, v)| (i, j, v.clone()))
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T> Iterator for CsrTripletIter<'a, T> {
|
||||
type Item = (usize, usize, &'a T);
|
||||
|
||||
|
|
|
@ -3,9 +3,12 @@ use nalgebra_sparse::SparseFormatErrorKind;
|
|||
use nalgebra::DMatrix;
|
||||
|
||||
use proptest::prelude::*;
|
||||
use proptest::sample::subsequence;
|
||||
|
||||
use crate::common::csc_strategy;
|
||||
|
||||
use std::collections::HashSet;
|
||||
|
||||
#[test]
|
||||
fn csc_matrix_valid_data() {
|
||||
// Construct matrix from valid data and check that selected methods return results
|
||||
|
@ -271,4 +274,58 @@ proptest! {
|
|||
prop_assert_eq!(dense_transpose, DMatrix::from(&csc_transpose));
|
||||
prop_assert_eq!(csc.nnz(), csc_transpose.nnz());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn csc_filter(
|
||||
(csc, triplet_subset)
|
||||
in csc_strategy()
|
||||
.prop_flat_map(|matrix| {
|
||||
let triplets: Vec<_> = matrix.triplet_iter().cloned_values().collect();
|
||||
let subset = subsequence(triplets, 0 ..= matrix.nnz())
|
||||
.prop_map(|triplet_subset| {
|
||||
let set: HashSet<_> = triplet_subset.into_iter().collect();
|
||||
set
|
||||
});
|
||||
(Just(matrix), subset)
|
||||
}))
|
||||
{
|
||||
// We generate a CscMatrix and a HashSet corresponding to a subset of the (i, j, v)
|
||||
// values in the matrix, which we use for filtering the matrix entries.
|
||||
// The resulting triplets in the filtered matrix must then be exactly equal to
|
||||
// the subset.
|
||||
let filtered = csc.filter(|i, j, v| triplet_subset.contains(&(i, j, *v)));
|
||||
let filtered_triplets: HashSet<_> = filtered
|
||||
.triplet_iter()
|
||||
.cloned_values()
|
||||
.collect();
|
||||
|
||||
prop_assert_eq!(filtered_triplets, triplet_subset);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn csc_lower_triangle_agrees_with_dense(csc in csc_strategy()) {
|
||||
let csc_lower_triangle = csc.lower_triangle();
|
||||
prop_assert_eq!(DMatrix::from(&csc_lower_triangle), DMatrix::from(&csc).lower_triangle());
|
||||
prop_assert!(csc_lower_triangle.nnz() <= csc.nnz());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn csc_upper_triangle_agrees_with_dense(csc in csc_strategy()) {
|
||||
let csc_upper_triangle = csc.upper_triangle();
|
||||
prop_assert_eq!(DMatrix::from(&csc_upper_triangle), DMatrix::from(&csc).upper_triangle());
|
||||
prop_assert!(csc_upper_triangle.nnz() <= csc.nnz());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn csc_diagonal_as_matrix(csc in csc_strategy()) {
|
||||
let d = csc.diagonal_as_matrix();
|
||||
let d_entries: HashSet<_> = d.triplet_iter().cloned_values().collect();
|
||||
let csc_diagonal_entries: HashSet<_> = csc
|
||||
.triplet_iter()
|
||||
.cloned_values()
|
||||
.filter(|&(i, j, _)| i == j)
|
||||
.collect();
|
||||
|
||||
prop_assert_eq!(d_entries, csc_diagonal_entries);
|
||||
}
|
||||
}
|
|
@ -1,9 +1,15 @@
|
|||
use nalgebra_sparse::csr::CsrMatrix;
|
||||
use nalgebra_sparse::SparseFormatErrorKind;
|
||||
use nalgebra::DMatrix;
|
||||
|
||||
use proptest::prelude::*;
|
||||
use proptest::sample::subsequence;
|
||||
|
||||
use crate::common::csr_strategy;
|
||||
|
||||
use std::collections::HashSet;
|
||||
|
||||
|
||||
#[test]
|
||||
fn csr_matrix_valid_data() {
|
||||
// Construct matrix from valid data and check that selected methods return results
|
||||
|
@ -269,4 +275,58 @@ proptest! {
|
|||
prop_assert_eq!(dense_transpose, DMatrix::from(&csr_transpose));
|
||||
prop_assert_eq!(csr.nnz(), csr_transpose.nnz());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn csr_filter(
|
||||
(csr, triplet_subset)
|
||||
in csr_strategy()
|
||||
.prop_flat_map(|matrix| {
|
||||
let triplets: Vec<_> = matrix.triplet_iter().cloned_values().collect();
|
||||
let subset = subsequence(triplets, 0 ..= matrix.nnz())
|
||||
.prop_map(|triplet_subset| {
|
||||
let set: HashSet<_> = triplet_subset.into_iter().collect();
|
||||
set
|
||||
});
|
||||
(Just(matrix), subset)
|
||||
}))
|
||||
{
|
||||
// We generate a CsrMatrix and a HashSet corresponding to a subset of the (i, j, v)
|
||||
// values in the matrix, which we use for filtering the matrix entries.
|
||||
// The resulting triplets in the filtered matrix must then be exactly equal to
|
||||
// the subset.
|
||||
let filtered = csr.filter(|i, j, v| triplet_subset.contains(&(i, j, *v)));
|
||||
let filtered_triplets: HashSet<_> = filtered
|
||||
.triplet_iter()
|
||||
.cloned_values()
|
||||
.collect();
|
||||
|
||||
prop_assert_eq!(filtered_triplets, triplet_subset);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn csr_lower_triangle_agrees_with_dense(csr in csr_strategy()) {
|
||||
let csr_lower_triangle = csr.lower_triangle();
|
||||
prop_assert_eq!(DMatrix::from(&csr_lower_triangle), DMatrix::from(&csr).lower_triangle());
|
||||
prop_assert!(csr_lower_triangle.nnz() <= csr.nnz());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn csr_upper_triangle_agrees_with_dense(csr in csr_strategy()) {
|
||||
let csr_upper_triangle = csr.upper_triangle();
|
||||
prop_assert_eq!(DMatrix::from(&csr_upper_triangle), DMatrix::from(&csr).upper_triangle());
|
||||
prop_assert!(csr_upper_triangle.nnz() <= csr.nnz());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn csr_diagonal_as_matrix(csr in csr_strategy()) {
|
||||
let d = csr.diagonal_as_matrix();
|
||||
let d_entries: HashSet<_> = d.triplet_iter().cloned_values().collect();
|
||||
let csr_diagonal_entries: HashSet<_> = csr
|
||||
.triplet_iter()
|
||||
.cloned_values()
|
||||
.filter(|&(i, j, _)| i == j)
|
||||
.collect();
|
||||
|
||||
prop_assert_eq!(d_entries, csr_diagonal_entries);
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue