From 830df6d07b81ac24e52fb1c83f24ed80213871ab Mon Sep 17 00:00:00 2001 From: Andreas Longva Date: Wed, 9 Dec 2020 14:42:31 +0100 Subject: [PATCH] Implement Csr/CscMatrix::transpose() --- nalgebra-sparse/src/csc.rs | 21 +++++++++++++++++ nalgebra-sparse/src/csr.rs | 23 ++++++++++++++++++- nalgebra-sparse/tests/common/mod.rs | 15 +++++++++++- .../tests/unit_tests/convert_serial.rs | 5 +--- nalgebra-sparse/tests/unit_tests/csc.rs | 20 ++++++++++++++++ nalgebra-sparse/tests/unit_tests/csr.rs | 18 +++++++++++++++ nalgebra-sparse/tests/unit_tests/ops.rs | 5 ++-- 7 files changed, 98 insertions(+), 9 deletions(-) diff --git a/nalgebra-sparse/src/csc.rs b/nalgebra-sparse/src/csc.rs index 941fb4c9..a94a5fdc 100644 --- a/nalgebra-sparse/src/csc.rs +++ b/nalgebra-sparse/src/csc.rs @@ -8,6 +8,8 @@ use std::slice::{IterMut, Iter}; use std::ops::Range; use num_traits::Zero; use std::ptr::slice_from_raw_parts_mut; +use crate::csr::CsrMatrix; +use nalgebra::Scalar; /// A CSC representation of a sparse matrix. /// @@ -295,6 +297,15 @@ impl CscMatrix { pub fn pattern(&self) -> &Arc { &self.sparsity_pattern } + + /// Reinterprets the CSC matrix as its transpose represented by a CSR matrix. + /// + /// This operation does not touch the CSC data, and is effectively a no-op. + pub fn transpose_as_csr(self) -> CsrMatrix { + let pattern = self.sparsity_pattern; + let values = self.values; + CsrMatrix::try_from_pattern_and_values(pattern, values).unwrap() + } } impl CscMatrix { @@ -323,6 +334,16 @@ impl CscMatrix { } } +impl CscMatrix + where + T: Scalar + Zero +{ + /// Compute the transpose of the matrix. + pub fn transpose(&self) -> CscMatrix { + CsrMatrix::from(self).transpose_as_csc() + } +} + /// Convert pattern format errors into more meaningful CSC-specific errors. /// /// This ensures that the terminology is consistent: we are talking about rows and columns, diff --git a/nalgebra-sparse/src/csr.rs b/nalgebra-sparse/src/csr.rs index 01d7533e..33348bd5 100644 --- a/nalgebra-sparse/src/csr.rs +++ b/nalgebra-sparse/src/csr.rs @@ -2,11 +2,14 @@ use crate::{SparseFormatError, SparseFormatErrorKind}; use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter}; +use crate::csc::CscMatrix; + +use nalgebra::Scalar; +use num_traits::Zero; use std::sync::Arc; use std::slice::{IterMut, Iter}; use std::ops::Range; -use num_traits::Zero; use std::ptr::slice_from_raw_parts_mut; /// A CSR representation of a sparse matrix. @@ -321,6 +324,24 @@ impl CsrMatrix { pub fn index(&self, row_index: usize, col_index: usize) -> T { self.get(row_index, col_index).unwrap() } + + /// Reinterprets the CSR matrix as its transpose represented by a CSC matrix. + /// This operation does not touch the CSR data, and is effectively a no-op. + pub fn transpose_as_csc(self) -> CscMatrix { + let pattern = self.sparsity_pattern; + let values = self.values; + CscMatrix::try_from_pattern_and_values(pattern, values).unwrap() + } +} + +impl CsrMatrix +where + T: Scalar + Zero +{ + /// Compute the transpose of the matrix. + pub fn transpose(&self) -> CsrMatrix { + CscMatrix::from(self).transpose_as_csr() + } } /// Convert pattern format errors into more meaningful CSR-specific errors. diff --git a/nalgebra-sparse/tests/common/mod.rs b/nalgebra-sparse/tests/common/mod.rs index bb77a10a..6e730b7d 100644 --- a/nalgebra-sparse/tests/common/mod.rs +++ b/nalgebra-sparse/tests/common/mod.rs @@ -1,3 +1,8 @@ +use proptest::strategy::Strategy; +use nalgebra_sparse::csr::CsrMatrix; +use nalgebra_sparse::proptest::{csr, csc}; +use nalgebra_sparse::csc::CscMatrix; + #[macro_export] macro_rules! assert_panics { ($e:expr) => {{ @@ -17,4 +22,12 @@ macro_rules! assert_panics { panic!("assert_panics!({}) failed: the expression did not panic.", expr_string); } }}; -} \ No newline at end of file +} + +pub fn csr_strategy() -> impl Strategy> { + csr(-5 ..= 5, 0 ..= 6usize, 0 ..= 6usize, 40) +} + +pub fn csc_strategy() -> impl Strategy> { + csc(-5 ..= 5, 0..=6usize, 0..=6usize, 40) +} diff --git a/nalgebra-sparse/tests/unit_tests/convert_serial.rs b/nalgebra-sparse/tests/unit_tests/convert_serial.rs index 5975966d..9dc13c71 100644 --- a/nalgebra-sparse/tests/unit_tests/convert_serial.rs +++ b/nalgebra-sparse/tests/unit_tests/convert_serial.rs @@ -11,6 +11,7 @@ use proptest::prelude::*; use nalgebra::DMatrix; use nalgebra_sparse::csr::CsrMatrix; use nalgebra_sparse::csc::CscMatrix; +use crate::common::csc_strategy; #[test] fn test_convert_dense_coo() { @@ -276,10 +277,6 @@ fn csr_strategy() -> impl Strategy> { csr(-5 ..= 5, 0..=6usize, 0..=6usize, 40) } -fn csc_strategy() -> impl Strategy> { - csc(-5 ..= 5, 0..=6usize, 0..=6usize, 40) -} - /// Avoid generating explicit zero values so that it is possible to reason about sparsity patterns fn non_zero_csr_strategy() -> impl Strategy> { csr(1 ..= 5, 0..=6usize, 0..=6usize, 40) diff --git a/nalgebra-sparse/tests/unit_tests/csc.rs b/nalgebra-sparse/tests/unit_tests/csc.rs index 140a5db2..a16e1686 100644 --- a/nalgebra-sparse/tests/unit_tests/csc.rs +++ b/nalgebra-sparse/tests/unit_tests/csc.rs @@ -1,5 +1,10 @@ use nalgebra_sparse::csc::CscMatrix; use nalgebra_sparse::SparseFormatErrorKind; +use nalgebra::DMatrix; + +use proptest::prelude::*; + +use crate::common::csc_strategy; #[test] fn csc_matrix_valid_data() { @@ -251,4 +256,19 @@ fn csc_matrix_get_index() { #[test] fn csc_matrix_col_iter() { // TODO +} + +proptest! { + #[test] + fn csc_double_transpose_is_identity(csc in csc_strategy()) { + prop_assert_eq!(csc.transpose().transpose(), csc); + } + + #[test] + fn csc_transpose_agrees_with_dense(csc in csc_strategy()) { + let dense_transpose = DMatrix::from(&csc).transpose(); + let csc_transpose = csc.transpose(); + prop_assert_eq!(dense_transpose, DMatrix::from(&csc_transpose)); + prop_assert_eq!(csc.nnz(), csc_transpose.nnz()); + } } \ No newline at end of file diff --git a/nalgebra-sparse/tests/unit_tests/csr.rs b/nalgebra-sparse/tests/unit_tests/csr.rs index ab6f698e..424bc2c1 100644 --- a/nalgebra-sparse/tests/unit_tests/csr.rs +++ b/nalgebra-sparse/tests/unit_tests/csr.rs @@ -1,5 +1,8 @@ use nalgebra_sparse::csr::CsrMatrix; use nalgebra_sparse::SparseFormatErrorKind; +use nalgebra::DMatrix; +use proptest::prelude::*; +use crate::common::csr_strategy; #[test] fn csr_matrix_valid_data() { @@ -250,5 +253,20 @@ fn csr_matrix_get_index() { #[test] fn csr_matrix_row_iter() { + // TODO +} +proptest! { + #[test] + fn csr_double_transpose_is_identity(csr in csr_strategy()) { + prop_assert_eq!(csr.transpose().transpose(), csr); + } + + #[test] + fn csr_transpose_agrees_with_dense(csr in csr_strategy()) { + let dense_transpose = DMatrix::from(&csr).transpose(); + let csr_transpose = csr.transpose(); + prop_assert_eq!(dense_transpose, DMatrix::from(&csr_transpose)); + prop_assert_eq!(csr.nnz(), csr_transpose.nnz()); + } } \ No newline at end of file diff --git a/nalgebra-sparse/tests/unit_tests/ops.rs b/nalgebra-sparse/tests/unit_tests/ops.rs index 14eb4aec..cc416790 100644 --- a/nalgebra-sparse/tests/unit_tests/ops.rs +++ b/nalgebra-sparse/tests/unit_tests/ops.rs @@ -13,6 +13,8 @@ use proptest::prelude::*; use std::panic::catch_unwind; use std::sync::Arc; +use crate::common::csr_strategy; + #[test] fn spmv_coo_agrees_with_dense_gemv() { let x = DVector::from_column_slice(&[2, 3, 4, 5]); @@ -89,9 +91,6 @@ fn spmm_csr_dense_args_strategy() -> impl Strategy> }) } -fn csr_strategy() -> impl Strategy> { - csr(-5 ..= 5, 0 ..= 6usize, 0 ..= 6usize, 40) -} fn dense_strategy() -> impl Strategy> { matrix(-5 ..= 5, 0 ..= 6, 0 ..= 6)