Merge pull request #1312 from arscisca/dev-DefaultTrait

Implement Default trait for sparse matrix types
This commit is contained in:
Sébastien Crozet 2023-11-12 23:27:37 +01:00 committed by GitHub
commit 83eccc6b8f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 113 additions and 0 deletions

View File

@ -226,6 +226,15 @@ impl<T> CsMatrix<T> {
} }
} }
impl<T> Default for CsMatrix<T> {
fn default() -> Self {
Self {
sparsity_pattern: Default::default(),
values: vec![],
}
}
}
impl<T: Scalar + One> CsMatrix<T> { impl<T: Scalar + One> CsMatrix<T> {
#[inline] #[inline]
pub fn identity(n: usize) -> Self { pub fn identity(n: usize) -> Self {

View File

@ -574,6 +574,14 @@ impl<T> CscMatrix<T> {
} }
} }
impl<T> Default for CscMatrix<T> {
fn default() -> Self {
Self {
cs: Default::default(),
}
}
}
/// Convert pattern format errors into more meaningful CSC-specific errors. /// Convert pattern format errors into more meaningful CSC-specific errors.
/// ///
/// This ensures that the terminology is consistent: we are talking about rows and columns, /// This ensures that the terminology is consistent: we are talking about rows and columns,

View File

@ -575,6 +575,14 @@ impl<T> CsrMatrix<T> {
} }
} }
impl<T> Default for CsrMatrix<T> {
fn default() -> Self {
Self {
cs: Default::default(),
}
}
}
/// Convert pattern format errors into more meaningful CSR-specific errors. /// Convert pattern format errors into more meaningful CSR-specific errors.
/// ///
/// This ensures that the terminology is consistent: we are talking about rows and columns, /// This ensures that the terminology is consistent: we are talking about rows and columns,

View File

@ -291,6 +291,16 @@ impl SparsityPattern {
} }
} }
impl Default for SparsityPattern {
fn default() -> Self {
Self {
major_offsets: vec![0],
minor_indices: vec![],
minor_dim: 0,
}
}
}
/// Error type for `SparsityPattern` format errors. /// Error type for `SparsityPattern` format errors.
#[non_exhaustive] #[non_exhaustive]
#[derive(Copy, Clone, Debug, PartialEq, Eq)] #[derive(Copy, Clone, Debug, PartialEq, Eq)]

View File

@ -12,6 +12,18 @@ use crate::common::csc_strategy;
use std::collections::HashSet; use std::collections::HashSet;
#[test]
fn csc_matrix_default() {
let matrix: CscMatrix<f32> = CscMatrix::default();
assert_eq!(matrix.nrows(), 0);
assert_eq!(matrix.ncols(), 0);
assert_eq!(matrix.nnz(), 0);
assert_eq!(matrix.values(), &[]);
assert!(matrix.get_entry(0, 0).is_none());
}
#[test] #[test]
fn csc_matrix_valid_data() { fn csc_matrix_valid_data() {
// Construct matrix from valid data and check that selected methods return results // Construct matrix from valid data and check that selected methods return results

View File

@ -12,6 +12,18 @@ use crate::common::csr_strategy;
use std::collections::HashSet; use std::collections::HashSet;
#[test]
fn csr_matrix_default() {
let matrix: CsrMatrix<f32> = CsrMatrix::default();
assert_eq!(matrix.nrows(), 0);
assert_eq!(matrix.ncols(), 0);
assert_eq!(matrix.nnz(), 0);
assert_eq!(matrix.values(), &[]);
assert!(matrix.get_entry(0, 0).is_none());
}
#[test] #[test]
fn csr_matrix_valid_data() { fn csr_matrix_valid_data() {
// Construct matrix from valid data and check that selected methods return results // Construct matrix from valid data and check that selected methods return results

View File

@ -1,5 +1,19 @@
use nalgebra_sparse::pattern::{SparsityPattern, SparsityPatternFormatError}; use nalgebra_sparse::pattern::{SparsityPattern, SparsityPatternFormatError};
#[test]
fn sparsity_pattern_default() {
// Check that the pattern created with `Default::default()` is equivalent to a zero-sized pattern.
let pattern = SparsityPattern::default();
let zero = SparsityPattern::zeros(0, 0);
assert_eq!(pattern.major_dim(), zero.major_dim());
assert_eq!(pattern.minor_dim(), zero.minor_dim());
assert_eq!(pattern.major_offsets(), zero.major_offsets());
assert_eq!(pattern.minor_indices(), zero.minor_indices());
assert_eq!(pattern.nnz(), 0);
}
#[test] #[test]
fn sparsity_pattern_valid_data() { fn sparsity_pattern_valid_data() {
// Construct pattern from valid data and check that selected methods return results // Construct pattern from valid data and check that selected methods return results

View File

@ -31,6 +31,46 @@ pub struct VecStorage<T, R: Dim, C: Dim> {
ncols: C, ncols: C,
} }
impl<T> Default for VecStorage<T, Dyn, Dyn> {
fn default() -> Self {
Self {
data: Vec::new(),
nrows: Dyn::from_usize(0),
ncols: Dyn::from_usize(0),
}
}
}
impl<T, R: DimName> Default for VecStorage<T, R, Dyn> {
fn default() -> Self {
Self {
data: Vec::new(),
nrows: R::name(),
ncols: Dyn::from_usize(0),
}
}
}
impl<T, C: DimName> Default for VecStorage<T, Dyn, C> {
fn default() -> Self {
Self {
data: Vec::new(),
nrows: Dyn::from_usize(0),
ncols: C::name(),
}
}
}
impl<T: Default, R: DimName, C: DimName> Default for VecStorage<T, R, C> {
fn default() -> Self {
let nrows = R::name();
let ncols = C::name();
let mut data = Vec::new();
data.resize_with(nrows.value() * ncols.value(), Default::default);
Self { data, nrows, ncols }
}
}
#[cfg(feature = "serde-serialize")] #[cfg(feature = "serde-serialize")]
impl<T, R: Dim, C: Dim> Serialize for VecStorage<T, R, C> impl<T, R: Dim, C: Dim> Serialize for VecStorage<T, R, C>
where where