Improve checking requirements for sorting column indices
This commit is contained in:
parent
469765a4e5
commit
4a97989738
@ -535,10 +535,6 @@ fn pattern_format_error_to_csc_error(err: SparsityPatternFormatError) -> SparseF
|
||||
use SparsityPatternFormatError::*;
|
||||
|
||||
match err {
|
||||
DifferentValuesIndicesLengths => E::from_kind_and_msg(
|
||||
K::InvalidStructure,
|
||||
"Lengths of values and column indices are not equal.",
|
||||
),
|
||||
InvalidOffsetArrayLength => E::from_kind_and_msg(
|
||||
K::InvalidStructure,
|
||||
"Length of col offset array is not equal to ncols + 1.",
|
||||
|
@ -5,13 +5,13 @@
|
||||
use crate::cs::{CsLane, CsLaneIter, CsLaneIterMut, CsLaneMut, CsMatrix};
|
||||
use crate::csc::CscMatrix;
|
||||
use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter};
|
||||
use crate::utils::{apply_permutation, first_and_last_offsets_are_ok};
|
||||
use crate::utils::apply_permutation;
|
||||
use crate::{SparseEntry, SparseEntryMut, SparseFormatError, SparseFormatErrorKind};
|
||||
|
||||
use nalgebra::Scalar;
|
||||
use num_traits::One;
|
||||
|
||||
use num_traits::Zero;
|
||||
|
||||
use std::slice::{Iter, IterMut};
|
||||
|
||||
/// A CSR representation of a sparse matrix.
|
||||
@ -190,39 +190,52 @@ impl<T> CsrMatrix<T> {
|
||||
where
|
||||
T: Scalar + Zero,
|
||||
{
|
||||
use SparsityPatternFormatError::*;
|
||||
|
||||
let mut p: Vec<usize> = (0..col_indices.len()).collect();
|
||||
let count = col_indices.len();
|
||||
let mut p: Vec<usize> = (0..count).collect();
|
||||
|
||||
if col_indices.len() != values.len() {
|
||||
return (Err(DifferentValuesIndicesLengths)).map_err(pattern_format_error_to_csr_error);
|
||||
return Err(SparseFormatError::from_kind_and_msg(
|
||||
SparseFormatErrorKind::InvalidStructure,
|
||||
"Number of values and column indices must be the same",
|
||||
));
|
||||
}
|
||||
|
||||
if first_and_last_offsets_are_ok(&row_offsets, &col_indices) {
|
||||
for (index, &offset) in row_offsets[0..row_offsets.len() - 1].iter().enumerate() {
|
||||
p[offset..row_offsets[index + 1]].sort_by(|a, b| {
|
||||
let x = &col_indices[*a];
|
||||
let y = &col_indices[*b];
|
||||
x.partial_cmp(y).unwrap()
|
||||
});
|
||||
if row_offsets.len() == 0 {
|
||||
return Err(SparseFormatError::from_kind_and_msg(
|
||||
SparseFormatErrorKind::InvalidStructure,
|
||||
"Number of offsets should be greater than 0",
|
||||
));
|
||||
}
|
||||
|
||||
for (index, &offset) in row_offsets[0..row_offsets.len() - 1].iter().enumerate() {
|
||||
let next_offset = row_offsets[index + 1];
|
||||
if next_offset > count {
|
||||
return Err(SparseFormatError::from_kind_and_msg(
|
||||
SparseFormatErrorKind::InvalidStructure,
|
||||
"No row offset should be greater than the number of column indices",
|
||||
));
|
||||
}
|
||||
|
||||
// permute indices
|
||||
let sorted_col_indices: Vec<usize> = p.iter().map(|i| col_indices[*i]).collect();
|
||||
|
||||
// permute values
|
||||
let mut output: Vec<T> = vec![T::zero(); p.len()];
|
||||
apply_permutation(&mut output[..p.len()], &values[..p.len()], &p[..p.len()]);
|
||||
|
||||
return Self::try_from_csr_data(
|
||||
num_rows,
|
||||
num_cols,
|
||||
row_offsets,
|
||||
sorted_col_indices,
|
||||
output,
|
||||
);
|
||||
p[offset..next_offset].sort_by(|a, b| {
|
||||
let x = &col_indices[*a];
|
||||
let y = &col_indices[*b];
|
||||
x.partial_cmp(y).unwrap()
|
||||
});
|
||||
}
|
||||
return (Err(InvalidOffsetFirstLast)).map_err(pattern_format_error_to_csr_error);
|
||||
|
||||
// permute indices
|
||||
let sorted_col_indices: Vec<usize> = p.iter().map(|i| col_indices[*i]).collect();
|
||||
|
||||
// permute values
|
||||
let mut sorted_vaues: Vec<T> = vec![T::zero(); count];
|
||||
apply_permutation(&mut sorted_vaues[..count], &values[..count], &p[..count]);
|
||||
|
||||
return Self::try_from_csr_data(
|
||||
num_rows,
|
||||
num_cols,
|
||||
row_offsets,
|
||||
sorted_col_indices,
|
||||
sorted_vaues,
|
||||
);
|
||||
}
|
||||
|
||||
/// Try to construct a CSR matrix from a sparsity pattern and associated non-zero values.
|
||||
@ -590,10 +603,6 @@ fn pattern_format_error_to_csr_error(err: SparsityPatternFormatError) -> SparseF
|
||||
use SparsityPatternFormatError::*;
|
||||
|
||||
match err {
|
||||
DifferentValuesIndicesLengths => E::from_kind_and_msg(
|
||||
K::InvalidStructure,
|
||||
"Lengths of values and column indices are not equal.",
|
||||
),
|
||||
InvalidOffsetArrayLength => E::from_kind_and_msg(
|
||||
K::InvalidStructure,
|
||||
"Length of row offset array is not equal to nrows + 1.",
|
||||
|
@ -149,9 +149,9 @@ pub mod csr;
|
||||
pub mod factorization;
|
||||
pub mod ops;
|
||||
pub mod pattern;
|
||||
pub mod utils;
|
||||
|
||||
pub(crate) mod cs;
|
||||
pub(crate) mod utils;
|
||||
|
||||
#[cfg(feature = "proptest-support")]
|
||||
pub mod proptest;
|
||||
|
@ -1,6 +1,5 @@
|
||||
//! Sparsity patterns for CSR and CSC matrices.
|
||||
use crate::cs::transpose_cs;
|
||||
use crate::utils::first_and_last_offsets_are_ok;
|
||||
use crate::SparseFormatError;
|
||||
use std::error::Error;
|
||||
use std::fmt;
|
||||
@ -133,8 +132,12 @@ impl SparsityPattern {
|
||||
}
|
||||
|
||||
// Check that the first and last offsets conform to the specification
|
||||
if !first_and_last_offsets_are_ok(&major_offsets, &minor_indices) {
|
||||
return Err(InvalidOffsetFirstLast);
|
||||
{
|
||||
let first_offset_ok = *major_offsets.first().unwrap() == 0;
|
||||
let last_offset_ok = *major_offsets.last().unwrap() == minor_indices.len();
|
||||
if !first_offset_ok || !last_offset_ok {
|
||||
return Err(InvalidOffsetFirstLast);
|
||||
}
|
||||
}
|
||||
|
||||
// Test that each lane has strictly monotonically increasing minor indices, i.e.
|
||||
@ -261,8 +264,6 @@ impl SparsityPattern {
|
||||
pub enum SparsityPatternFormatError {
|
||||
/// Indicates an invalid number of offsets.
|
||||
///
|
||||
/// Indicates that column indices and values have different lengths.
|
||||
DifferentValuesIndicesLengths,
|
||||
/// The number of offsets must be equal to (major_dim + 1).
|
||||
InvalidOffsetArrayLength,
|
||||
/// Indicates that the first or last entry in the offset array did not conform to
|
||||
@ -291,8 +292,7 @@ impl From<SparsityPatternFormatError> for SparseFormatError {
|
||||
use SparsityPatternFormatError::DuplicateEntry as PatternDuplicateEntry;
|
||||
use SparsityPatternFormatError::*;
|
||||
match err {
|
||||
DifferentValuesIndicesLengths
|
||||
| InvalidOffsetArrayLength
|
||||
InvalidOffsetArrayLength
|
||||
| InvalidOffsetFirstLast
|
||||
| NonmonotonicOffsets
|
||||
| NonmonotonicMinorIndices => {
|
||||
@ -313,9 +313,6 @@ impl From<SparsityPatternFormatError> for SparseFormatError {
|
||||
impl fmt::Display for SparsityPatternFormatError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
SparsityPatternFormatError::DifferentValuesIndicesLengths => {
|
||||
write!(f, "Lengths of values and column indices are not equal.")
|
||||
}
|
||||
SparsityPatternFormatError::InvalidOffsetArrayLength => {
|
||||
write!(f, "Length of offset array is not equal to (major_dim + 1).")
|
||||
}
|
||||
|
@ -1,17 +1,5 @@
|
||||
//! Helper functions for sparse matrix computations
|
||||
|
||||
/// Check that the first and last offsets conform to the specification of a CSR matrix
|
||||
#[inline]
|
||||
#[must_use]
|
||||
pub fn first_and_last_offsets_are_ok(
|
||||
major_offsets: &Vec<usize>,
|
||||
minor_indices: &Vec<usize>,
|
||||
) -> bool {
|
||||
let first_offset_ok = *major_offsets.first().unwrap() == 0;
|
||||
let last_offset_ok = *major_offsets.last().unwrap() == minor_indices.len();
|
||||
return first_offset_ok && last_offset_ok;
|
||||
}
|
||||
|
||||
/// permutes entries of in_slice according to permutation slice and puts them to out_slice
|
||||
#[inline]
|
||||
pub fn apply_permutation<T: Clone>(out_slice: &mut [T], in_slice: &[T], permutation: &[usize]) {
|
||||
|
Loading…
Reference in New Issue
Block a user