Implement Csr/CscMatrix::transpose()
This commit is contained in:
parent
8b7b836a37
commit
830df6d07b
|
@ -8,6 +8,8 @@ use std::slice::{IterMut, Iter};
|
||||||
use std::ops::Range;
|
use std::ops::Range;
|
||||||
use num_traits::Zero;
|
use num_traits::Zero;
|
||||||
use std::ptr::slice_from_raw_parts_mut;
|
use std::ptr::slice_from_raw_parts_mut;
|
||||||
|
use crate::csr::CsrMatrix;
|
||||||
|
use nalgebra::Scalar;
|
||||||
|
|
||||||
/// A CSC representation of a sparse matrix.
|
/// A CSC representation of a sparse matrix.
|
||||||
///
|
///
|
||||||
|
@ -295,6 +297,15 @@ impl<T> CscMatrix<T> {
|
||||||
pub fn pattern(&self) -> &Arc<SparsityPattern> {
|
pub fn pattern(&self) -> &Arc<SparsityPattern> {
|
||||||
&self.sparsity_pattern
|
&self.sparsity_pattern
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Reinterprets the CSC matrix as its transpose represented by a CSR matrix.
|
||||||
|
///
|
||||||
|
/// This operation does not touch the CSC data, and is effectively a no-op.
|
||||||
|
pub fn transpose_as_csr(self) -> CsrMatrix<T> {
|
||||||
|
let pattern = self.sparsity_pattern;
|
||||||
|
let values = self.values;
|
||||||
|
CsrMatrix::try_from_pattern_and_values(pattern, values).unwrap()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: Clone + Zero> CscMatrix<T> {
|
impl<T: Clone + Zero> CscMatrix<T> {
|
||||||
|
@ -323,6 +334,16 @@ impl<T: Clone + Zero> CscMatrix<T> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<T> CscMatrix<T>
|
||||||
|
where
|
||||||
|
T: Scalar + Zero
|
||||||
|
{
|
||||||
|
/// Compute the transpose of the matrix.
|
||||||
|
pub fn transpose(&self) -> CscMatrix<T> {
|
||||||
|
CsrMatrix::from(self).transpose_as_csc()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Convert pattern format errors into more meaningful CSC-specific errors.
|
/// Convert pattern format errors into more meaningful CSC-specific errors.
|
||||||
///
|
///
|
||||||
/// This ensures that the terminology is consistent: we are talking about rows and columns,
|
/// This ensures that the terminology is consistent: we are talking about rows and columns,
|
||||||
|
|
|
@ -2,11 +2,14 @@
|
||||||
|
|
||||||
use crate::{SparseFormatError, SparseFormatErrorKind};
|
use crate::{SparseFormatError, SparseFormatErrorKind};
|
||||||
use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter};
|
use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter};
|
||||||
|
use crate::csc::CscMatrix;
|
||||||
|
|
||||||
|
use nalgebra::Scalar;
|
||||||
|
use num_traits::Zero;
|
||||||
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::slice::{IterMut, Iter};
|
use std::slice::{IterMut, Iter};
|
||||||
use std::ops::Range;
|
use std::ops::Range;
|
||||||
use num_traits::Zero;
|
|
||||||
use std::ptr::slice_from_raw_parts_mut;
|
use std::ptr::slice_from_raw_parts_mut;
|
||||||
|
|
||||||
/// A CSR representation of a sparse matrix.
|
/// A CSR representation of a sparse matrix.
|
||||||
|
@ -321,6 +324,24 @@ impl<T: Clone + Zero> CsrMatrix<T> {
|
||||||
pub fn index(&self, row_index: usize, col_index: usize) -> T {
|
pub fn index(&self, row_index: usize, col_index: usize) -> T {
|
||||||
self.get(row_index, col_index).unwrap()
|
self.get(row_index, col_index).unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Reinterprets the CSR matrix as its transpose represented by a CSC matrix.
|
||||||
|
/// This operation does not touch the CSR data, and is effectively a no-op.
|
||||||
|
pub fn transpose_as_csc(self) -> CscMatrix<T> {
|
||||||
|
let pattern = self.sparsity_pattern;
|
||||||
|
let values = self.values;
|
||||||
|
CscMatrix::try_from_pattern_and_values(pattern, values).unwrap()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T> CsrMatrix<T>
|
||||||
|
where
|
||||||
|
T: Scalar + Zero
|
||||||
|
{
|
||||||
|
/// Compute the transpose of the matrix.
|
||||||
|
pub fn transpose(&self) -> CsrMatrix<T> {
|
||||||
|
CscMatrix::from(self).transpose_as_csr()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Convert pattern format errors into more meaningful CSR-specific errors.
|
/// Convert pattern format errors into more meaningful CSR-specific errors.
|
||||||
|
|
|
@ -1,3 +1,8 @@
|
||||||
|
use proptest::strategy::Strategy;
|
||||||
|
use nalgebra_sparse::csr::CsrMatrix;
|
||||||
|
use nalgebra_sparse::proptest::{csr, csc};
|
||||||
|
use nalgebra_sparse::csc::CscMatrix;
|
||||||
|
|
||||||
#[macro_export]
|
#[macro_export]
|
||||||
macro_rules! assert_panics {
|
macro_rules! assert_panics {
|
||||||
($e:expr) => {{
|
($e:expr) => {{
|
||||||
|
@ -18,3 +23,11 @@ macro_rules! assert_panics {
|
||||||
}
|
}
|
||||||
}};
|
}};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn csr_strategy() -> impl Strategy<Value=CsrMatrix<i32>> {
|
||||||
|
csr(-5 ..= 5, 0 ..= 6usize, 0 ..= 6usize, 40)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn csc_strategy() -> impl Strategy<Value=CscMatrix<i32>> {
|
||||||
|
csc(-5 ..= 5, 0..=6usize, 0..=6usize, 40)
|
||||||
|
}
|
||||||
|
|
|
@ -11,6 +11,7 @@ use proptest::prelude::*;
|
||||||
use nalgebra::DMatrix;
|
use nalgebra::DMatrix;
|
||||||
use nalgebra_sparse::csr::CsrMatrix;
|
use nalgebra_sparse::csr::CsrMatrix;
|
||||||
use nalgebra_sparse::csc::CscMatrix;
|
use nalgebra_sparse::csc::CscMatrix;
|
||||||
|
use crate::common::csc_strategy;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_convert_dense_coo() {
|
fn test_convert_dense_coo() {
|
||||||
|
@ -276,10 +277,6 @@ fn csr_strategy() -> impl Strategy<Value=CsrMatrix<i32>> {
|
||||||
csr(-5 ..= 5, 0..=6usize, 0..=6usize, 40)
|
csr(-5 ..= 5, 0..=6usize, 0..=6usize, 40)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn csc_strategy() -> impl Strategy<Value=CscMatrix<i32>> {
|
|
||||||
csc(-5 ..= 5, 0..=6usize, 0..=6usize, 40)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Avoid generating explicit zero values so that it is possible to reason about sparsity patterns
|
/// Avoid generating explicit zero values so that it is possible to reason about sparsity patterns
|
||||||
fn non_zero_csr_strategy() -> impl Strategy<Value=CsrMatrix<i32>> {
|
fn non_zero_csr_strategy() -> impl Strategy<Value=CsrMatrix<i32>> {
|
||||||
csr(1 ..= 5, 0..=6usize, 0..=6usize, 40)
|
csr(1 ..= 5, 0..=6usize, 0..=6usize, 40)
|
||||||
|
|
|
@ -1,5 +1,10 @@
|
||||||
use nalgebra_sparse::csc::CscMatrix;
|
use nalgebra_sparse::csc::CscMatrix;
|
||||||
use nalgebra_sparse::SparseFormatErrorKind;
|
use nalgebra_sparse::SparseFormatErrorKind;
|
||||||
|
use nalgebra::DMatrix;
|
||||||
|
|
||||||
|
use proptest::prelude::*;
|
||||||
|
|
||||||
|
use crate::common::csc_strategy;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn csc_matrix_valid_data() {
|
fn csc_matrix_valid_data() {
|
||||||
|
@ -252,3 +257,18 @@ fn csc_matrix_get_index() {
|
||||||
fn csc_matrix_col_iter() {
|
fn csc_matrix_col_iter() {
|
||||||
// TODO
|
// TODO
|
||||||
}
|
}
|
||||||
|
|
||||||
|
proptest! {
|
||||||
|
#[test]
|
||||||
|
fn csc_double_transpose_is_identity(csc in csc_strategy()) {
|
||||||
|
prop_assert_eq!(csc.transpose().transpose(), csc);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn csc_transpose_agrees_with_dense(csc in csc_strategy()) {
|
||||||
|
let dense_transpose = DMatrix::from(&csc).transpose();
|
||||||
|
let csc_transpose = csc.transpose();
|
||||||
|
prop_assert_eq!(dense_transpose, DMatrix::from(&csc_transpose));
|
||||||
|
prop_assert_eq!(csc.nnz(), csc_transpose.nnz());
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,5 +1,8 @@
|
||||||
use nalgebra_sparse::csr::CsrMatrix;
|
use nalgebra_sparse::csr::CsrMatrix;
|
||||||
use nalgebra_sparse::SparseFormatErrorKind;
|
use nalgebra_sparse::SparseFormatErrorKind;
|
||||||
|
use nalgebra::DMatrix;
|
||||||
|
use proptest::prelude::*;
|
||||||
|
use crate::common::csr_strategy;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn csr_matrix_valid_data() {
|
fn csr_matrix_valid_data() {
|
||||||
|
@ -250,5 +253,20 @@ fn csr_matrix_get_index() {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn csr_matrix_row_iter() {
|
fn csr_matrix_row_iter() {
|
||||||
|
// TODO
|
||||||
|
}
|
||||||
|
|
||||||
|
proptest! {
|
||||||
|
#[test]
|
||||||
|
fn csr_double_transpose_is_identity(csr in csr_strategy()) {
|
||||||
|
prop_assert_eq!(csr.transpose().transpose(), csr);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn csr_transpose_agrees_with_dense(csr in csr_strategy()) {
|
||||||
|
let dense_transpose = DMatrix::from(&csr).transpose();
|
||||||
|
let csr_transpose = csr.transpose();
|
||||||
|
prop_assert_eq!(dense_transpose, DMatrix::from(&csr_transpose));
|
||||||
|
prop_assert_eq!(csr.nnz(), csr_transpose.nnz());
|
||||||
|
}
|
||||||
}
|
}
|
|
@ -13,6 +13,8 @@ use proptest::prelude::*;
|
||||||
use std::panic::catch_unwind;
|
use std::panic::catch_unwind;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use crate::common::csr_strategy;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn spmv_coo_agrees_with_dense_gemv() {
|
fn spmv_coo_agrees_with_dense_gemv() {
|
||||||
let x = DVector::from_column_slice(&[2, 3, 4, 5]);
|
let x = DVector::from_column_slice(&[2, 3, 4, 5]);
|
||||||
|
@ -89,9 +91,6 @@ fn spmm_csr_dense_args_strategy() -> impl Strategy<Value=SpmmCsrDenseArgs<i32>>
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn csr_strategy() -> impl Strategy<Value=CsrMatrix<i32>> {
|
|
||||||
csr(-5 ..= 5, 0 ..= 6usize, 0 ..= 6usize, 40)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn dense_strategy() -> impl Strategy<Value=DMatrix<i32>> {
|
fn dense_strategy() -> impl Strategy<Value=DMatrix<i32>> {
|
||||||
matrix(-5 ..= 5, 0 ..= 6, 0 ..= 6)
|
matrix(-5 ..= 5, 0 ..= 6, 0 ..= 6)
|
||||||
|
|
Loading…
Reference in New Issue