From 4a979897384d0e06beb640236194b7e5f3a3ecd9 Mon Sep 17 00:00:00 2001 From: Anton Date: Mon, 11 Oct 2021 22:11:50 +0200 Subject: [PATCH] Improve checking requirements for sorting column indices --- nalgebra-sparse/src/csc.rs | 4 -- nalgebra-sparse/src/csr.rs | 75 +++++++++++++++++++--------------- nalgebra-sparse/src/lib.rs | 2 +- nalgebra-sparse/src/pattern.rs | 17 ++++---- nalgebra-sparse/src/utils.rs | 12 ------ 5 files changed, 50 insertions(+), 60 deletions(-) diff --git a/nalgebra-sparse/src/csc.rs b/nalgebra-sparse/src/csc.rs index b770bbf3..607cc0cf 100644 --- a/nalgebra-sparse/src/csc.rs +++ b/nalgebra-sparse/src/csc.rs @@ -535,10 +535,6 @@ fn pattern_format_error_to_csc_error(err: SparsityPatternFormatError) -> SparseF use SparsityPatternFormatError::*; match err { - DifferentValuesIndicesLengths => E::from_kind_and_msg( - K::InvalidStructure, - "Lengths of values and column indices are not equal.", - ), InvalidOffsetArrayLength => E::from_kind_and_msg( K::InvalidStructure, "Length of col offset array is not equal to ncols + 1.", diff --git a/nalgebra-sparse/src/csr.rs b/nalgebra-sparse/src/csr.rs index c0bef4d3..beafba05 100644 --- a/nalgebra-sparse/src/csr.rs +++ b/nalgebra-sparse/src/csr.rs @@ -5,13 +5,13 @@ use crate::cs::{CsLane, CsLaneIter, CsLaneIterMut, CsLaneMut, CsMatrix}; use crate::csc::CscMatrix; use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter}; -use crate::utils::{apply_permutation, first_and_last_offsets_are_ok}; +use crate::utils::apply_permutation; use crate::{SparseEntry, SparseEntryMut, SparseFormatError, SparseFormatErrorKind}; use nalgebra::Scalar; use num_traits::One; - use num_traits::Zero; + use std::slice::{Iter, IterMut}; /// A CSR representation of a sparse matrix. @@ -190,39 +190,52 @@ impl CsrMatrix { where T: Scalar + Zero, { - use SparsityPatternFormatError::*; - - let mut p: Vec = (0..col_indices.len()).collect(); + let count = col_indices.len(); + let mut p: Vec = (0..count).collect(); if col_indices.len() != values.len() { - return (Err(DifferentValuesIndicesLengths)).map_err(pattern_format_error_to_csr_error); + return Err(SparseFormatError::from_kind_and_msg( + SparseFormatErrorKind::InvalidStructure, + "Number of values and column indices must be the same", + )); } - if first_and_last_offsets_are_ok(&row_offsets, &col_indices) { - for (index, &offset) in row_offsets[0..row_offsets.len() - 1].iter().enumerate() { - p[offset..row_offsets[index + 1]].sort_by(|a, b| { - let x = &col_indices[*a]; - let y = &col_indices[*b]; - x.partial_cmp(y).unwrap() - }); + if row_offsets.len() == 0 { + return Err(SparseFormatError::from_kind_and_msg( + SparseFormatErrorKind::InvalidStructure, + "Number of offsets should be greater than 0", + )); + } + + for (index, &offset) in row_offsets[0..row_offsets.len() - 1].iter().enumerate() { + let next_offset = row_offsets[index + 1]; + if next_offset > count { + return Err(SparseFormatError::from_kind_and_msg( + SparseFormatErrorKind::InvalidStructure, + "No row offset should be greater than the number of column indices", + )); } - - // permute indices - let sorted_col_indices: Vec = p.iter().map(|i| col_indices[*i]).collect(); - - // permute values - let mut output: Vec = vec![T::zero(); p.len()]; - apply_permutation(&mut output[..p.len()], &values[..p.len()], &p[..p.len()]); - - return Self::try_from_csr_data( - num_rows, - num_cols, - row_offsets, - sorted_col_indices, - output, - ); + p[offset..next_offset].sort_by(|a, b| { + let x = &col_indices[*a]; + let y = &col_indices[*b]; + x.partial_cmp(y).unwrap() + }); } - return (Err(InvalidOffsetFirstLast)).map_err(pattern_format_error_to_csr_error); + + // permute indices + let sorted_col_indices: Vec = p.iter().map(|i| col_indices[*i]).collect(); + + // permute values + let mut sorted_vaues: Vec = vec![T::zero(); count]; + apply_permutation(&mut sorted_vaues[..count], &values[..count], &p[..count]); + + return Self::try_from_csr_data( + num_rows, + num_cols, + row_offsets, + sorted_col_indices, + sorted_vaues, + ); } /// Try to construct a CSR matrix from a sparsity pattern and associated non-zero values. @@ -590,10 +603,6 @@ fn pattern_format_error_to_csr_error(err: SparsityPatternFormatError) -> SparseF use SparsityPatternFormatError::*; match err { - DifferentValuesIndicesLengths => E::from_kind_and_msg( - K::InvalidStructure, - "Lengths of values and column indices are not equal.", - ), InvalidOffsetArrayLength => E::from_kind_and_msg( K::InvalidStructure, "Length of row offset array is not equal to nrows + 1.", diff --git a/nalgebra-sparse/src/lib.rs b/nalgebra-sparse/src/lib.rs index 607a1abf..64331817 100644 --- a/nalgebra-sparse/src/lib.rs +++ b/nalgebra-sparse/src/lib.rs @@ -149,9 +149,9 @@ pub mod csr; pub mod factorization; pub mod ops; pub mod pattern; -pub mod utils; pub(crate) mod cs; +pub(crate) mod utils; #[cfg(feature = "proptest-support")] pub mod proptest; diff --git a/nalgebra-sparse/src/pattern.rs b/nalgebra-sparse/src/pattern.rs index 3e30ee9b..85f6bc1a 100644 --- a/nalgebra-sparse/src/pattern.rs +++ b/nalgebra-sparse/src/pattern.rs @@ -1,6 +1,5 @@ //! Sparsity patterns for CSR and CSC matrices. use crate::cs::transpose_cs; -use crate::utils::first_and_last_offsets_are_ok; use crate::SparseFormatError; use std::error::Error; use std::fmt; @@ -133,8 +132,12 @@ impl SparsityPattern { } // Check that the first and last offsets conform to the specification - if !first_and_last_offsets_are_ok(&major_offsets, &minor_indices) { - return Err(InvalidOffsetFirstLast); + { + let first_offset_ok = *major_offsets.first().unwrap() == 0; + let last_offset_ok = *major_offsets.last().unwrap() == minor_indices.len(); + if !first_offset_ok || !last_offset_ok { + return Err(InvalidOffsetFirstLast); + } } // Test that each lane has strictly monotonically increasing minor indices, i.e. @@ -261,8 +264,6 @@ impl SparsityPattern { pub enum SparsityPatternFormatError { /// Indicates an invalid number of offsets. /// - /// Indicates that column indices and values have different lengths. - DifferentValuesIndicesLengths, /// The number of offsets must be equal to (major_dim + 1). InvalidOffsetArrayLength, /// Indicates that the first or last entry in the offset array did not conform to @@ -291,8 +292,7 @@ impl From for SparseFormatError { use SparsityPatternFormatError::DuplicateEntry as PatternDuplicateEntry; use SparsityPatternFormatError::*; match err { - DifferentValuesIndicesLengths - | InvalidOffsetArrayLength + InvalidOffsetArrayLength | InvalidOffsetFirstLast | NonmonotonicOffsets | NonmonotonicMinorIndices => { @@ -313,9 +313,6 @@ impl From for SparseFormatError { impl fmt::Display for SparsityPatternFormatError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - SparsityPatternFormatError::DifferentValuesIndicesLengths => { - write!(f, "Lengths of values and column indices are not equal.") - } SparsityPatternFormatError::InvalidOffsetArrayLength => { write!(f, "Length of offset array is not equal to (major_dim + 1).") } diff --git a/nalgebra-sparse/src/utils.rs b/nalgebra-sparse/src/utils.rs index 411e6e0a..a5da85c5 100644 --- a/nalgebra-sparse/src/utils.rs +++ b/nalgebra-sparse/src/utils.rs @@ -1,17 +1,5 @@ //! Helper functions for sparse matrix computations -/// Check that the first and last offsets conform to the specification of a CSR matrix -#[inline] -#[must_use] -pub fn first_and_last_offsets_are_ok( - major_offsets: &Vec, - minor_indices: &Vec, -) -> bool { - let first_offset_ok = *major_offsets.first().unwrap() == 0; - let last_offset_ok = *major_offsets.last().unwrap() == minor_indices.len(); - return first_offset_ok && last_offset_ok; -} - /// permutes entries of in_slice according to permutation slice and puts them to out_slice #[inline] pub fn apply_permutation(out_slice: &mut [T], in_slice: &[T], permutation: &[usize]) {