Improve checking requirements for sorting column indices

This commit is contained in:
Anton 2021-10-11 22:11:50 +02:00
parent 469765a4e5
commit 4a97989738
5 changed files with 50 additions and 60 deletions

View File

@ -535,10 +535,6 @@ fn pattern_format_error_to_csc_error(err: SparsityPatternFormatError) -> SparseF
use SparsityPatternFormatError::*; use SparsityPatternFormatError::*;
match err { match err {
DifferentValuesIndicesLengths => E::from_kind_and_msg(
K::InvalidStructure,
"Lengths of values and column indices are not equal.",
),
InvalidOffsetArrayLength => E::from_kind_and_msg( InvalidOffsetArrayLength => E::from_kind_and_msg(
K::InvalidStructure, K::InvalidStructure,
"Length of col offset array is not equal to ncols + 1.", "Length of col offset array is not equal to ncols + 1.",

View File

@ -5,13 +5,13 @@
use crate::cs::{CsLane, CsLaneIter, CsLaneIterMut, CsLaneMut, CsMatrix}; use crate::cs::{CsLane, CsLaneIter, CsLaneIterMut, CsLaneMut, CsMatrix};
use crate::csc::CscMatrix; use crate::csc::CscMatrix;
use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter}; use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter};
use crate::utils::{apply_permutation, first_and_last_offsets_are_ok}; use crate::utils::apply_permutation;
use crate::{SparseEntry, SparseEntryMut, SparseFormatError, SparseFormatErrorKind}; use crate::{SparseEntry, SparseEntryMut, SparseFormatError, SparseFormatErrorKind};
use nalgebra::Scalar; use nalgebra::Scalar;
use num_traits::One; use num_traits::One;
use num_traits::Zero; use num_traits::Zero;
use std::slice::{Iter, IterMut}; use std::slice::{Iter, IterMut};
/// A CSR representation of a sparse matrix. /// A CSR representation of a sparse matrix.
@ -190,17 +190,32 @@ impl<T> CsrMatrix<T> {
where where
T: Scalar + Zero, T: Scalar + Zero,
{ {
use SparsityPatternFormatError::*; let count = col_indices.len();
let mut p: Vec<usize> = (0..count).collect();
let mut p: Vec<usize> = (0..col_indices.len()).collect();
if col_indices.len() != values.len() { if col_indices.len() != values.len() {
return (Err(DifferentValuesIndicesLengths)).map_err(pattern_format_error_to_csr_error); return Err(SparseFormatError::from_kind_and_msg(
SparseFormatErrorKind::InvalidStructure,
"Number of values and column indices must be the same",
));
}
if row_offsets.len() == 0 {
return Err(SparseFormatError::from_kind_and_msg(
SparseFormatErrorKind::InvalidStructure,
"Number of offsets should be greater than 0",
));
} }
if first_and_last_offsets_are_ok(&row_offsets, &col_indices) {
for (index, &offset) in row_offsets[0..row_offsets.len() - 1].iter().enumerate() { for (index, &offset) in row_offsets[0..row_offsets.len() - 1].iter().enumerate() {
p[offset..row_offsets[index + 1]].sort_by(|a, b| { let next_offset = row_offsets[index + 1];
if next_offset > count {
return Err(SparseFormatError::from_kind_and_msg(
SparseFormatErrorKind::InvalidStructure,
"No row offset should be greater than the number of column indices",
));
}
p[offset..next_offset].sort_by(|a, b| {
let x = &col_indices[*a]; let x = &col_indices[*a];
let y = &col_indices[*b]; let y = &col_indices[*b];
x.partial_cmp(y).unwrap() x.partial_cmp(y).unwrap()
@ -211,19 +226,17 @@ impl<T> CsrMatrix<T> {
let sorted_col_indices: Vec<usize> = p.iter().map(|i| col_indices[*i]).collect(); let sorted_col_indices: Vec<usize> = p.iter().map(|i| col_indices[*i]).collect();
// permute values // permute values
let mut output: Vec<T> = vec![T::zero(); p.len()]; let mut sorted_vaues: Vec<T> = vec![T::zero(); count];
apply_permutation(&mut output[..p.len()], &values[..p.len()], &p[..p.len()]); apply_permutation(&mut sorted_vaues[..count], &values[..count], &p[..count]);
return Self::try_from_csr_data( return Self::try_from_csr_data(
num_rows, num_rows,
num_cols, num_cols,
row_offsets, row_offsets,
sorted_col_indices, sorted_col_indices,
output, sorted_vaues,
); );
} }
return (Err(InvalidOffsetFirstLast)).map_err(pattern_format_error_to_csr_error);
}
/// Try to construct a CSR matrix from a sparsity pattern and associated non-zero values. /// Try to construct a CSR matrix from a sparsity pattern and associated non-zero values.
/// ///
@ -590,10 +603,6 @@ fn pattern_format_error_to_csr_error(err: SparsityPatternFormatError) -> SparseF
use SparsityPatternFormatError::*; use SparsityPatternFormatError::*;
match err { match err {
DifferentValuesIndicesLengths => E::from_kind_and_msg(
K::InvalidStructure,
"Lengths of values and column indices are not equal.",
),
InvalidOffsetArrayLength => E::from_kind_and_msg( InvalidOffsetArrayLength => E::from_kind_and_msg(
K::InvalidStructure, K::InvalidStructure,
"Length of row offset array is not equal to nrows + 1.", "Length of row offset array is not equal to nrows + 1.",

View File

@ -149,9 +149,9 @@ pub mod csr;
pub mod factorization; pub mod factorization;
pub mod ops; pub mod ops;
pub mod pattern; pub mod pattern;
pub mod utils;
pub(crate) mod cs; pub(crate) mod cs;
pub(crate) mod utils;
#[cfg(feature = "proptest-support")] #[cfg(feature = "proptest-support")]
pub mod proptest; pub mod proptest;

View File

@ -1,6 +1,5 @@
//! Sparsity patterns for CSR and CSC matrices. //! Sparsity patterns for CSR and CSC matrices.
use crate::cs::transpose_cs; use crate::cs::transpose_cs;
use crate::utils::first_and_last_offsets_are_ok;
use crate::SparseFormatError; use crate::SparseFormatError;
use std::error::Error; use std::error::Error;
use std::fmt; use std::fmt;
@ -133,9 +132,13 @@ impl SparsityPattern {
} }
// Check that the first and last offsets conform to the specification // Check that the first and last offsets conform to the specification
if !first_and_last_offsets_are_ok(&major_offsets, &minor_indices) { {
let first_offset_ok = *major_offsets.first().unwrap() == 0;
let last_offset_ok = *major_offsets.last().unwrap() == minor_indices.len();
if !first_offset_ok || !last_offset_ok {
return Err(InvalidOffsetFirstLast); return Err(InvalidOffsetFirstLast);
} }
}
// Test that each lane has strictly monotonically increasing minor indices, i.e. // Test that each lane has strictly monotonically increasing minor indices, i.e.
// minor indices within a lane are sorted, unique. In addition, each minor index // minor indices within a lane are sorted, unique. In addition, each minor index
@ -261,8 +264,6 @@ impl SparsityPattern {
pub enum SparsityPatternFormatError { pub enum SparsityPatternFormatError {
/// Indicates an invalid number of offsets. /// Indicates an invalid number of offsets.
/// ///
/// Indicates that column indices and values have different lengths.
DifferentValuesIndicesLengths,
/// The number of offsets must be equal to (major_dim + 1). /// The number of offsets must be equal to (major_dim + 1).
InvalidOffsetArrayLength, InvalidOffsetArrayLength,
/// Indicates that the first or last entry in the offset array did not conform to /// Indicates that the first or last entry in the offset array did not conform to
@ -291,8 +292,7 @@ impl From<SparsityPatternFormatError> for SparseFormatError {
use SparsityPatternFormatError::DuplicateEntry as PatternDuplicateEntry; use SparsityPatternFormatError::DuplicateEntry as PatternDuplicateEntry;
use SparsityPatternFormatError::*; use SparsityPatternFormatError::*;
match err { match err {
DifferentValuesIndicesLengths InvalidOffsetArrayLength
| InvalidOffsetArrayLength
| InvalidOffsetFirstLast | InvalidOffsetFirstLast
| NonmonotonicOffsets | NonmonotonicOffsets
| NonmonotonicMinorIndices => { | NonmonotonicMinorIndices => {
@ -313,9 +313,6 @@ impl From<SparsityPatternFormatError> for SparseFormatError {
impl fmt::Display for SparsityPatternFormatError { impl fmt::Display for SparsityPatternFormatError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self { match self {
SparsityPatternFormatError::DifferentValuesIndicesLengths => {
write!(f, "Lengths of values and column indices are not equal.")
}
SparsityPatternFormatError::InvalidOffsetArrayLength => { SparsityPatternFormatError::InvalidOffsetArrayLength => {
write!(f, "Length of offset array is not equal to (major_dim + 1).") write!(f, "Length of offset array is not equal to (major_dim + 1).")
} }

View File

@ -1,17 +1,5 @@
//! Helper functions for sparse matrix computations //! Helper functions for sparse matrix computations
/// Check that the first and last offsets conform to the specification of a CSR matrix
#[inline]
#[must_use]
pub fn first_and_last_offsets_are_ok(
major_offsets: &Vec<usize>,
minor_indices: &Vec<usize>,
) -> bool {
let first_offset_ok = *major_offsets.first().unwrap() == 0;
let last_offset_ok = *major_offsets.last().unwrap() == minor_indices.len();
return first_offset_ok && last_offset_ok;
}
/// permutes entries of in_slice according to permutation slice and puts them to out_slice /// permutes entries of in_slice according to permutation slice and puts them to out_slice
#[inline] #[inline]
pub fn apply_permutation<T: Clone>(out_slice: &mut [T], in_slice: &[T], permutation: &[usize]) { pub fn apply_permutation<T: Clone>(out_slice: &mut [T], in_slice: &[T], permutation: &[usize]) {