Implement Csr/CscMatrix::transpose()

This commit is contained in:
Andreas Longva 2020-12-09 14:42:31 +01:00
parent 8b7b836a37
commit 830df6d07b
7 changed files with 98 additions and 9 deletions

View File

@ -8,6 +8,8 @@ use std::slice::{IterMut, Iter};
use std::ops::Range; use std::ops::Range;
use num_traits::Zero; use num_traits::Zero;
use std::ptr::slice_from_raw_parts_mut; use std::ptr::slice_from_raw_parts_mut;
use crate::csr::CsrMatrix;
use nalgebra::Scalar;
/// A CSC representation of a sparse matrix. /// A CSC representation of a sparse matrix.
/// ///
@ -295,6 +297,15 @@ impl<T> CscMatrix<T> {
pub fn pattern(&self) -> &Arc<SparsityPattern> { pub fn pattern(&self) -> &Arc<SparsityPattern> {
&self.sparsity_pattern &self.sparsity_pattern
} }
/// Reinterprets the CSC matrix as its transpose represented by a CSR matrix.
///
/// This operation does not touch the CSC data, and is effectively a no-op.
pub fn transpose_as_csr(self) -> CsrMatrix<T> {
let pattern = self.sparsity_pattern;
let values = self.values;
CsrMatrix::try_from_pattern_and_values(pattern, values).unwrap()
}
} }
impl<T: Clone + Zero> CscMatrix<T> { impl<T: Clone + Zero> CscMatrix<T> {
@ -323,6 +334,16 @@ impl<T: Clone + Zero> CscMatrix<T> {
} }
} }
impl<T> CscMatrix<T>
where
T: Scalar + Zero
{
/// Compute the transpose of the matrix.
pub fn transpose(&self) -> CscMatrix<T> {
CsrMatrix::from(self).transpose_as_csc()
}
}
/// Convert pattern format errors into more meaningful CSC-specific errors. /// Convert pattern format errors into more meaningful CSC-specific errors.
/// ///
/// This ensures that the terminology is consistent: we are talking about rows and columns, /// This ensures that the terminology is consistent: we are talking about rows and columns,

View File

@ -2,11 +2,14 @@
use crate::{SparseFormatError, SparseFormatErrorKind}; use crate::{SparseFormatError, SparseFormatErrorKind};
use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter}; use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter};
use crate::csc::CscMatrix;
use nalgebra::Scalar;
use num_traits::Zero;
use std::sync::Arc; use std::sync::Arc;
use std::slice::{IterMut, Iter}; use std::slice::{IterMut, Iter};
use std::ops::Range; use std::ops::Range;
use num_traits::Zero;
use std::ptr::slice_from_raw_parts_mut; use std::ptr::slice_from_raw_parts_mut;
/// A CSR representation of a sparse matrix. /// A CSR representation of a sparse matrix.
@ -321,6 +324,24 @@ impl<T: Clone + Zero> CsrMatrix<T> {
pub fn index(&self, row_index: usize, col_index: usize) -> T { pub fn index(&self, row_index: usize, col_index: usize) -> T {
self.get(row_index, col_index).unwrap() self.get(row_index, col_index).unwrap()
} }
/// Reinterprets the CSR matrix as its transpose represented by a CSC matrix.
/// This operation does not touch the CSR data, and is effectively a no-op.
pub fn transpose_as_csc(self) -> CscMatrix<T> {
let pattern = self.sparsity_pattern;
let values = self.values;
CscMatrix::try_from_pattern_and_values(pattern, values).unwrap()
}
}
impl<T> CsrMatrix<T>
where
T: Scalar + Zero
{
/// Compute the transpose of the matrix.
pub fn transpose(&self) -> CsrMatrix<T> {
CscMatrix::from(self).transpose_as_csr()
}
} }
/// Convert pattern format errors into more meaningful CSR-specific errors. /// Convert pattern format errors into more meaningful CSR-specific errors.

View File

@ -1,3 +1,8 @@
use proptest::strategy::Strategy;
use nalgebra_sparse::csr::CsrMatrix;
use nalgebra_sparse::proptest::{csr, csc};
use nalgebra_sparse::csc::CscMatrix;
#[macro_export] #[macro_export]
macro_rules! assert_panics { macro_rules! assert_panics {
($e:expr) => {{ ($e:expr) => {{
@ -17,4 +22,12 @@ macro_rules! assert_panics {
panic!("assert_panics!({}) failed: the expression did not panic.", expr_string); panic!("assert_panics!({}) failed: the expression did not panic.", expr_string);
} }
}}; }};
} }
pub fn csr_strategy() -> impl Strategy<Value=CsrMatrix<i32>> {
csr(-5 ..= 5, 0 ..= 6usize, 0 ..= 6usize, 40)
}
pub fn csc_strategy() -> impl Strategy<Value=CscMatrix<i32>> {
csc(-5 ..= 5, 0..=6usize, 0..=6usize, 40)
}

View File

@ -11,6 +11,7 @@ use proptest::prelude::*;
use nalgebra::DMatrix; use nalgebra::DMatrix;
use nalgebra_sparse::csr::CsrMatrix; use nalgebra_sparse::csr::CsrMatrix;
use nalgebra_sparse::csc::CscMatrix; use nalgebra_sparse::csc::CscMatrix;
use crate::common::csc_strategy;
#[test] #[test]
fn test_convert_dense_coo() { fn test_convert_dense_coo() {
@ -276,10 +277,6 @@ fn csr_strategy() -> impl Strategy<Value=CsrMatrix<i32>> {
csr(-5 ..= 5, 0..=6usize, 0..=6usize, 40) csr(-5 ..= 5, 0..=6usize, 0..=6usize, 40)
} }
fn csc_strategy() -> impl Strategy<Value=CscMatrix<i32>> {
csc(-5 ..= 5, 0..=6usize, 0..=6usize, 40)
}
/// Avoid generating explicit zero values so that it is possible to reason about sparsity patterns /// Avoid generating explicit zero values so that it is possible to reason about sparsity patterns
fn non_zero_csr_strategy() -> impl Strategy<Value=CsrMatrix<i32>> { fn non_zero_csr_strategy() -> impl Strategy<Value=CsrMatrix<i32>> {
csr(1 ..= 5, 0..=6usize, 0..=6usize, 40) csr(1 ..= 5, 0..=6usize, 0..=6usize, 40)

View File

@ -1,5 +1,10 @@
use nalgebra_sparse::csc::CscMatrix; use nalgebra_sparse::csc::CscMatrix;
use nalgebra_sparse::SparseFormatErrorKind; use nalgebra_sparse::SparseFormatErrorKind;
use nalgebra::DMatrix;
use proptest::prelude::*;
use crate::common::csc_strategy;
#[test] #[test]
fn csc_matrix_valid_data() { fn csc_matrix_valid_data() {
@ -251,4 +256,19 @@ fn csc_matrix_get_index() {
#[test] #[test]
fn csc_matrix_col_iter() { fn csc_matrix_col_iter() {
// TODO // TODO
}
proptest! {
#[test]
fn csc_double_transpose_is_identity(csc in csc_strategy()) {
prop_assert_eq!(csc.transpose().transpose(), csc);
}
#[test]
fn csc_transpose_agrees_with_dense(csc in csc_strategy()) {
let dense_transpose = DMatrix::from(&csc).transpose();
let csc_transpose = csc.transpose();
prop_assert_eq!(dense_transpose, DMatrix::from(&csc_transpose));
prop_assert_eq!(csc.nnz(), csc_transpose.nnz());
}
} }

View File

@ -1,5 +1,8 @@
use nalgebra_sparse::csr::CsrMatrix; use nalgebra_sparse::csr::CsrMatrix;
use nalgebra_sparse::SparseFormatErrorKind; use nalgebra_sparse::SparseFormatErrorKind;
use nalgebra::DMatrix;
use proptest::prelude::*;
use crate::common::csr_strategy;
#[test] #[test]
fn csr_matrix_valid_data() { fn csr_matrix_valid_data() {
@ -250,5 +253,20 @@ fn csr_matrix_get_index() {
#[test] #[test]
fn csr_matrix_row_iter() { fn csr_matrix_row_iter() {
// TODO
}
proptest! {
#[test]
fn csr_double_transpose_is_identity(csr in csr_strategy()) {
prop_assert_eq!(csr.transpose().transpose(), csr);
}
#[test]
fn csr_transpose_agrees_with_dense(csr in csr_strategy()) {
let dense_transpose = DMatrix::from(&csr).transpose();
let csr_transpose = csr.transpose();
prop_assert_eq!(dense_transpose, DMatrix::from(&csr_transpose));
prop_assert_eq!(csr.nnz(), csr_transpose.nnz());
}
} }

View File

@ -13,6 +13,8 @@ use proptest::prelude::*;
use std::panic::catch_unwind; use std::panic::catch_unwind;
use std::sync::Arc; use std::sync::Arc;
use crate::common::csr_strategy;
#[test] #[test]
fn spmv_coo_agrees_with_dense_gemv() { fn spmv_coo_agrees_with_dense_gemv() {
let x = DVector::from_column_slice(&[2, 3, 4, 5]); let x = DVector::from_column_slice(&[2, 3, 4, 5]);
@ -89,9 +91,6 @@ fn spmm_csr_dense_args_strategy() -> impl Strategy<Value=SpmmCsrDenseArgs<i32>>
}) })
} }
fn csr_strategy() -> impl Strategy<Value=CsrMatrix<i32>> {
csr(-5 ..= 5, 0 ..= 6usize, 0 ..= 6usize, 40)
}
fn dense_strategy() -> impl Strategy<Value=DMatrix<i32>> { fn dense_strategy() -> impl Strategy<Value=DMatrix<i32>> {
matrix(-5 ..= 5, 0 ..= 6, 0 ..= 6) matrix(-5 ..= 5, 0 ..= 6, 0 ..= 6)