Improve checking requirements for sorting column indices

This commit is contained in:
Anton 2021-10-11 22:11:50 +02:00
parent 469765a4e5
commit 4a97989738
5 changed files with 50 additions and 60 deletions

View File

@ -535,10 +535,6 @@ fn pattern_format_error_to_csc_error(err: SparsityPatternFormatError) -> SparseF
use SparsityPatternFormatError::*;
match err {
DifferentValuesIndicesLengths => E::from_kind_and_msg(
K::InvalidStructure,
"Lengths of values and column indices are not equal.",
),
InvalidOffsetArrayLength => E::from_kind_and_msg(
K::InvalidStructure,
"Length of col offset array is not equal to ncols + 1.",

View File

@ -5,13 +5,13 @@
use crate::cs::{CsLane, CsLaneIter, CsLaneIterMut, CsLaneMut, CsMatrix};
use crate::csc::CscMatrix;
use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter};
use crate::utils::{apply_permutation, first_and_last_offsets_are_ok};
use crate::utils::apply_permutation;
use crate::{SparseEntry, SparseEntryMut, SparseFormatError, SparseFormatErrorKind};
use nalgebra::Scalar;
use num_traits::One;
use num_traits::Zero;
use std::slice::{Iter, IterMut};
/// A CSR representation of a sparse matrix.
@ -190,39 +190,52 @@ impl<T> CsrMatrix<T> {
where
T: Scalar + Zero,
{
use SparsityPatternFormatError::*;
let mut p: Vec<usize> = (0..col_indices.len()).collect();
let count = col_indices.len();
let mut p: Vec<usize> = (0..count).collect();
if col_indices.len() != values.len() {
return (Err(DifferentValuesIndicesLengths)).map_err(pattern_format_error_to_csr_error);
return Err(SparseFormatError::from_kind_and_msg(
SparseFormatErrorKind::InvalidStructure,
"Number of values and column indices must be the same",
));
}
if first_and_last_offsets_are_ok(&row_offsets, &col_indices) {
for (index, &offset) in row_offsets[0..row_offsets.len() - 1].iter().enumerate() {
p[offset..row_offsets[index + 1]].sort_by(|a, b| {
let x = &col_indices[*a];
let y = &col_indices[*b];
x.partial_cmp(y).unwrap()
});
if row_offsets.len() == 0 {
return Err(SparseFormatError::from_kind_and_msg(
SparseFormatErrorKind::InvalidStructure,
"Number of offsets should be greater than 0",
));
}
for (index, &offset) in row_offsets[0..row_offsets.len() - 1].iter().enumerate() {
let next_offset = row_offsets[index + 1];
if next_offset > count {
return Err(SparseFormatError::from_kind_and_msg(
SparseFormatErrorKind::InvalidStructure,
"No row offset should be greater than the number of column indices",
));
}
// permute indices
let sorted_col_indices: Vec<usize> = p.iter().map(|i| col_indices[*i]).collect();
// permute values
let mut output: Vec<T> = vec![T::zero(); p.len()];
apply_permutation(&mut output[..p.len()], &values[..p.len()], &p[..p.len()]);
return Self::try_from_csr_data(
num_rows,
num_cols,
row_offsets,
sorted_col_indices,
output,
);
p[offset..next_offset].sort_by(|a, b| {
let x = &col_indices[*a];
let y = &col_indices[*b];
x.partial_cmp(y).unwrap()
});
}
return (Err(InvalidOffsetFirstLast)).map_err(pattern_format_error_to_csr_error);
// permute indices
let sorted_col_indices: Vec<usize> = p.iter().map(|i| col_indices[*i]).collect();
// permute values
let mut sorted_vaues: Vec<T> = vec![T::zero(); count];
apply_permutation(&mut sorted_vaues[..count], &values[..count], &p[..count]);
return Self::try_from_csr_data(
num_rows,
num_cols,
row_offsets,
sorted_col_indices,
sorted_vaues,
);
}
/// Try to construct a CSR matrix from a sparsity pattern and associated non-zero values.
@ -590,10 +603,6 @@ fn pattern_format_error_to_csr_error(err: SparsityPatternFormatError) -> SparseF
use SparsityPatternFormatError::*;
match err {
DifferentValuesIndicesLengths => E::from_kind_and_msg(
K::InvalidStructure,
"Lengths of values and column indices are not equal.",
),
InvalidOffsetArrayLength => E::from_kind_and_msg(
K::InvalidStructure,
"Length of row offset array is not equal to nrows + 1.",

View File

@ -149,9 +149,9 @@ pub mod csr;
pub mod factorization;
pub mod ops;
pub mod pattern;
pub mod utils;
pub(crate) mod cs;
pub(crate) mod utils;
#[cfg(feature = "proptest-support")]
pub mod proptest;

View File

@ -1,6 +1,5 @@
//! Sparsity patterns for CSR and CSC matrices.
use crate::cs::transpose_cs;
use crate::utils::first_and_last_offsets_are_ok;
use crate::SparseFormatError;
use std::error::Error;
use std::fmt;
@ -133,8 +132,12 @@ impl SparsityPattern {
}
// Check that the first and last offsets conform to the specification
if !first_and_last_offsets_are_ok(&major_offsets, &minor_indices) {
return Err(InvalidOffsetFirstLast);
{
let first_offset_ok = *major_offsets.first().unwrap() == 0;
let last_offset_ok = *major_offsets.last().unwrap() == minor_indices.len();
if !first_offset_ok || !last_offset_ok {
return Err(InvalidOffsetFirstLast);
}
}
// Test that each lane has strictly monotonically increasing minor indices, i.e.
@ -261,8 +264,6 @@ impl SparsityPattern {
pub enum SparsityPatternFormatError {
/// Indicates an invalid number of offsets.
///
/// Indicates that column indices and values have different lengths.
DifferentValuesIndicesLengths,
/// The number of offsets must be equal to (major_dim + 1).
InvalidOffsetArrayLength,
/// Indicates that the first or last entry in the offset array did not conform to
@ -291,8 +292,7 @@ impl From<SparsityPatternFormatError> for SparseFormatError {
use SparsityPatternFormatError::DuplicateEntry as PatternDuplicateEntry;
use SparsityPatternFormatError::*;
match err {
DifferentValuesIndicesLengths
| InvalidOffsetArrayLength
InvalidOffsetArrayLength
| InvalidOffsetFirstLast
| NonmonotonicOffsets
| NonmonotonicMinorIndices => {
@ -313,9 +313,6 @@ impl From<SparsityPatternFormatError> for SparseFormatError {
impl fmt::Display for SparsityPatternFormatError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
SparsityPatternFormatError::DifferentValuesIndicesLengths => {
write!(f, "Lengths of values and column indices are not equal.")
}
SparsityPatternFormatError::InvalidOffsetArrayLength => {
write!(f, "Length of offset array is not equal to (major_dim + 1).")
}

View File

@ -1,17 +1,5 @@
//! Helper functions for sparse matrix computations
/// Check that the first and last offsets conform to the specification of a CSR matrix
#[inline]
#[must_use]
pub fn first_and_last_offsets_are_ok(
major_offsets: &Vec<usize>,
minor_indices: &Vec<usize>,
) -> bool {
let first_offset_ok = *major_offsets.first().unwrap() == 0;
let last_offset_ok = *major_offsets.last().unwrap() == minor_indices.len();
return first_offset_ok && last_offset_ok;
}
/// permutes entries of in_slice according to permutation slice and puts them to out_slice
#[inline]
pub fn apply_permutation<T: Clone>(out_slice: &mut [T], in_slice: &[T], permutation: &[usize]) {