Rename some Csr/Csc/SparsityPattern methods

This commit is contained in:
Andreas Longva 2021-01-25 16:04:29 +01:00
parent cf1bd284f1
commit 7b6333e9d1
8 changed files with 30 additions and 22 deletions

View File

@ -26,7 +26,7 @@ impl<T> CsMatrix<T> {
#[inline] #[inline]
pub fn new(major_dim: usize, minor_dim: usize) -> Self { pub fn new(major_dim: usize, minor_dim: usize) -> Self {
Self { Self {
sparsity_pattern: SparsityPattern::new(major_dim, minor_dim), sparsity_pattern: SparsityPattern::zeros(major_dim, minor_dim),
values: vec![], values: vec![],
} }
} }
@ -185,6 +185,15 @@ impl<T> CsMatrix<T> {
Self::from_pattern_and_values(new_pattern, new_values) Self::from_pattern_and_values(new_pattern, new_values)
} }
/// Returns the diagonal of the matrix as a sparse matrix.
pub fn diagonal_as_matrix(&self) -> Self
where
T: Clone
{
// TODO: This might be faster with a binary search for each diagonal entry
self.filter(|i, j, _| i == j)
}
} }
impl<T: Scalar + One> CsMatrix<T> { impl<T: Scalar + One> CsMatrix<T> {

View File

@ -1,4 +1,7 @@
//! An implementation of the CSC sparse matrix format. //! An implementation of the CSC sparse matrix format.
//!
//! This is the module-level documentation. See [`CscMatrix`] for the main documentation of the
//! CSC implementation.
use crate::{SparseFormatError, SparseFormatErrorKind, SparseEntry, SparseEntryMut}; use crate::{SparseFormatError, SparseFormatErrorKind, SparseEntry, SparseEntryMut};
use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter}; use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter};
@ -125,7 +128,7 @@ pub struct CscMatrix<T> {
impl<T> CscMatrix<T> { impl<T> CscMatrix<T> {
/// Create a zero CSC matrix with no explicitly stored entries. /// Create a zero CSC matrix with no explicitly stored entries.
pub fn new(nrows: usize, ncols: usize) -> Self { pub fn zeros(nrows: usize, ncols: usize) -> Self {
Self { Self {
cs: CsMatrix::new(ncols, nrows) cs: CsMatrix::new(ncols, nrows)
} }
@ -469,11 +472,11 @@ impl<T> CscMatrix<T> {
} }
/// Returns the diagonal of the matrix as a sparse matrix. /// Returns the diagonal of the matrix as a sparse matrix.
pub fn diagonal_as_matrix(&self) -> Self pub fn diagonal_as_csc(&self) -> Self
where where
T: Clone T: Clone
{ {
self.filter(|i, j, _| i == j) Self { cs: self.cs.diagonal_as_matrix() }
} }
} }

View File

@ -1,4 +1,7 @@
//! An implementation of the CSR sparse matrix format. //! An implementation of the CSR sparse matrix format.
//!
//! This is the module-level documentation. See [`CsrMatrix`] for the main documentation of the
//! CSC implementation.
use crate::{SparseFormatError, SparseFormatErrorKind, SparseEntry, SparseEntryMut}; use crate::{SparseFormatError, SparseFormatErrorKind, SparseEntry, SparseEntryMut};
use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter}; use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter};
use crate::csc::CscMatrix; use crate::csc::CscMatrix;
@ -125,7 +128,7 @@ pub struct CsrMatrix<T> {
impl<T> CsrMatrix<T> { impl<T> CsrMatrix<T> {
/// Create a zero CSR matrix with no explicitly stored entries. /// Create a zero CSR matrix with no explicitly stored entries.
pub fn new(nrows: usize, ncols: usize) -> Self { pub fn zeros(nrows: usize, ncols: usize) -> Self {
Self { Self {
cs: CsMatrix::new(nrows, ncols) cs: CsMatrix::new(nrows, ncols)
} }
@ -469,11 +472,11 @@ impl<T> CsrMatrix<T> {
} }
/// Returns the diagonal of the matrix as a sparse matrix. /// Returns the diagonal of the matrix as a sparse matrix.
pub fn diagonal_as_matrix(&self) -> Self pub fn diagonal_as_csr(&self) -> Self
where where
T: Clone T: Clone
{ {
self.filter(|i, j, _| i == j) Self { cs: self.cs.diagonal_as_matrix() }
} }
} }

View File

@ -128,13 +128,6 @@
//! assert_matrix_eq!(y, y_expected, comp = abs, tol = 1e-9); //! assert_matrix_eq!(y, y_expected, comp = abs, tol = 1e-9);
//! } //! }
//! ``` //! ```
//!
//! TODO: Write docs on the following:
//!
//! - Overall design ("easy API" vs. "expert" API etc.)
//! - Conversions (From, explicit "expert" API etc.)
//! - Matrix ops design
//! - Proptest and matrixcompare integrations
#![deny(non_camel_case_types)] #![deny(non_camel_case_types)]
#![deny(unused_parens)] #![deny(unused_parens)]
#![deny(non_upper_case_globals)] #![deny(non_upper_case_globals)]

View File

@ -50,7 +50,7 @@ pub struct SparsityPattern {
impl SparsityPattern { impl SparsityPattern {
/// Create a sparsity pattern of the given dimensions without explicitly stored entries. /// Create a sparsity pattern of the given dimensions without explicitly stored entries.
pub fn new(major_dim: usize, minor_dim: usize) -> Self { pub fn zeros(major_dim: usize, minor_dim: usize) -> Self {
Self { Self {
major_offsets: vec![0; major_dim + 1], major_offsets: vec![0; major_dim + 1],
minor_indices: vec![], minor_indices: vec![],

View File

@ -21,7 +21,7 @@ fn csc_matrix_valid_data() {
let values = Vec::<i32>::new(); let values = Vec::<i32>::new();
let mut matrix = CscMatrix::try_from_csc_data(2, 3, offsets, indices, values).unwrap(); let mut matrix = CscMatrix::try_from_csc_data(2, 3, offsets, indices, values).unwrap();
assert_eq!(matrix, CscMatrix::new(2, 3)); assert_eq!(matrix, CscMatrix::zeros(2, 3));
assert_eq!(matrix.nrows(), 2); assert_eq!(matrix.nrows(), 2);
assert_eq!(matrix.ncols(), 3); assert_eq!(matrix.ncols(), 3);
@ -317,8 +317,8 @@ proptest! {
} }
#[test] #[test]
fn csc_diagonal_as_matrix(csc in csc_strategy()) { fn csc_diagonal_as_csc(csc in csc_strategy()) {
let d = csc.diagonal_as_matrix(); let d = csc.diagonal_as_csc();
let d_entries: HashSet<_> = d.triplet_iter().cloned_values().collect(); let d_entries: HashSet<_> = d.triplet_iter().cloned_values().collect();
let csc_diagonal_entries: HashSet<_> = csc let csc_diagonal_entries: HashSet<_> = csc
.triplet_iter() .triplet_iter()

View File

@ -22,7 +22,7 @@ fn csr_matrix_valid_data() {
let values = Vec::<i32>::new(); let values = Vec::<i32>::new();
let mut matrix = CsrMatrix::try_from_csr_data(3, 2, offsets, indices, values).unwrap(); let mut matrix = CsrMatrix::try_from_csr_data(3, 2, offsets, indices, values).unwrap();
assert_eq!(matrix, CsrMatrix::new(3, 2)); assert_eq!(matrix, CsrMatrix::zeros(3, 2));
assert_eq!(matrix.nrows(), 3); assert_eq!(matrix.nrows(), 3);
assert_eq!(matrix.ncols(), 2); assert_eq!(matrix.ncols(), 2);
@ -318,8 +318,8 @@ proptest! {
} }
#[test] #[test]
fn csr_diagonal_as_matrix(csr in csr_strategy()) { fn csr_diagonal_as_csr(csr in csr_strategy()) {
let d = csr.diagonal_as_matrix(); let d = csr.diagonal_as_csr();
let d_entries: HashSet<_> = d.triplet_iter().cloned_values().collect(); let d_entries: HashSet<_> = d.triplet_iter().cloned_values().collect();
let csr_diagonal_entries: HashSet<_> = csr let csr_diagonal_entries: HashSet<_> = csr
.triplet_iter() .triplet_iter()

View File

@ -23,7 +23,7 @@ fn sparsity_pattern_valid_data() {
assert_eq!(pattern.lane(2), &[]); assert_eq!(pattern.lane(2), &[]);
assert!(pattern.entries().next().is_none()); assert!(pattern.entries().next().is_none());
assert_eq!(pattern, SparsityPattern::new(3, 2)); assert_eq!(pattern, SparsityPattern::zeros(3, 2));
let (offsets, indices) = pattern.disassemble(); let (offsets, indices) = pattern.disassemble();
assert_eq!(offsets, vec![0, 0, 0, 0]); assert_eq!(offsets, vec![0, 0, 0, 0]);