Improve checking requirements for sorting column indices
This commit is contained in:
parent
469765a4e5
commit
4a97989738
|
@ -535,10 +535,6 @@ fn pattern_format_error_to_csc_error(err: SparsityPatternFormatError) -> SparseF
|
||||||
use SparsityPatternFormatError::*;
|
use SparsityPatternFormatError::*;
|
||||||
|
|
||||||
match err {
|
match err {
|
||||||
DifferentValuesIndicesLengths => E::from_kind_and_msg(
|
|
||||||
K::InvalidStructure,
|
|
||||||
"Lengths of values and column indices are not equal.",
|
|
||||||
),
|
|
||||||
InvalidOffsetArrayLength => E::from_kind_and_msg(
|
InvalidOffsetArrayLength => E::from_kind_and_msg(
|
||||||
K::InvalidStructure,
|
K::InvalidStructure,
|
||||||
"Length of col offset array is not equal to ncols + 1.",
|
"Length of col offset array is not equal to ncols + 1.",
|
||||||
|
|
|
@ -5,13 +5,13 @@
|
||||||
use crate::cs::{CsLane, CsLaneIter, CsLaneIterMut, CsLaneMut, CsMatrix};
|
use crate::cs::{CsLane, CsLaneIter, CsLaneIterMut, CsLaneMut, CsMatrix};
|
||||||
use crate::csc::CscMatrix;
|
use crate::csc::CscMatrix;
|
||||||
use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter};
|
use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter};
|
||||||
use crate::utils::{apply_permutation, first_and_last_offsets_are_ok};
|
use crate::utils::apply_permutation;
|
||||||
use crate::{SparseEntry, SparseEntryMut, SparseFormatError, SparseFormatErrorKind};
|
use crate::{SparseEntry, SparseEntryMut, SparseFormatError, SparseFormatErrorKind};
|
||||||
|
|
||||||
use nalgebra::Scalar;
|
use nalgebra::Scalar;
|
||||||
use num_traits::One;
|
use num_traits::One;
|
||||||
|
|
||||||
use num_traits::Zero;
|
use num_traits::Zero;
|
||||||
|
|
||||||
use std::slice::{Iter, IterMut};
|
use std::slice::{Iter, IterMut};
|
||||||
|
|
||||||
/// A CSR representation of a sparse matrix.
|
/// A CSR representation of a sparse matrix.
|
||||||
|
@ -190,17 +190,32 @@ impl<T> CsrMatrix<T> {
|
||||||
where
|
where
|
||||||
T: Scalar + Zero,
|
T: Scalar + Zero,
|
||||||
{
|
{
|
||||||
use SparsityPatternFormatError::*;
|
let count = col_indices.len();
|
||||||
|
let mut p: Vec<usize> = (0..count).collect();
|
||||||
let mut p: Vec<usize> = (0..col_indices.len()).collect();
|
|
||||||
|
|
||||||
if col_indices.len() != values.len() {
|
if col_indices.len() != values.len() {
|
||||||
return (Err(DifferentValuesIndicesLengths)).map_err(pattern_format_error_to_csr_error);
|
return Err(SparseFormatError::from_kind_and_msg(
|
||||||
|
SparseFormatErrorKind::InvalidStructure,
|
||||||
|
"Number of values and column indices must be the same",
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
if row_offsets.len() == 0 {
|
||||||
|
return Err(SparseFormatError::from_kind_and_msg(
|
||||||
|
SparseFormatErrorKind::InvalidStructure,
|
||||||
|
"Number of offsets should be greater than 0",
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
if first_and_last_offsets_are_ok(&row_offsets, &col_indices) {
|
|
||||||
for (index, &offset) in row_offsets[0..row_offsets.len() - 1].iter().enumerate() {
|
for (index, &offset) in row_offsets[0..row_offsets.len() - 1].iter().enumerate() {
|
||||||
p[offset..row_offsets[index + 1]].sort_by(|a, b| {
|
let next_offset = row_offsets[index + 1];
|
||||||
|
if next_offset > count {
|
||||||
|
return Err(SparseFormatError::from_kind_and_msg(
|
||||||
|
SparseFormatErrorKind::InvalidStructure,
|
||||||
|
"No row offset should be greater than the number of column indices",
|
||||||
|
));
|
||||||
|
}
|
||||||
|
p[offset..next_offset].sort_by(|a, b| {
|
||||||
let x = &col_indices[*a];
|
let x = &col_indices[*a];
|
||||||
let y = &col_indices[*b];
|
let y = &col_indices[*b];
|
||||||
x.partial_cmp(y).unwrap()
|
x.partial_cmp(y).unwrap()
|
||||||
|
@ -211,19 +226,17 @@ impl<T> CsrMatrix<T> {
|
||||||
let sorted_col_indices: Vec<usize> = p.iter().map(|i| col_indices[*i]).collect();
|
let sorted_col_indices: Vec<usize> = p.iter().map(|i| col_indices[*i]).collect();
|
||||||
|
|
||||||
// permute values
|
// permute values
|
||||||
let mut output: Vec<T> = vec![T::zero(); p.len()];
|
let mut sorted_vaues: Vec<T> = vec![T::zero(); count];
|
||||||
apply_permutation(&mut output[..p.len()], &values[..p.len()], &p[..p.len()]);
|
apply_permutation(&mut sorted_vaues[..count], &values[..count], &p[..count]);
|
||||||
|
|
||||||
return Self::try_from_csr_data(
|
return Self::try_from_csr_data(
|
||||||
num_rows,
|
num_rows,
|
||||||
num_cols,
|
num_cols,
|
||||||
row_offsets,
|
row_offsets,
|
||||||
sorted_col_indices,
|
sorted_col_indices,
|
||||||
output,
|
sorted_vaues,
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
return (Err(InvalidOffsetFirstLast)).map_err(pattern_format_error_to_csr_error);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Try to construct a CSR matrix from a sparsity pattern and associated non-zero values.
|
/// Try to construct a CSR matrix from a sparsity pattern and associated non-zero values.
|
||||||
///
|
///
|
||||||
|
@ -590,10 +603,6 @@ fn pattern_format_error_to_csr_error(err: SparsityPatternFormatError) -> SparseF
|
||||||
use SparsityPatternFormatError::*;
|
use SparsityPatternFormatError::*;
|
||||||
|
|
||||||
match err {
|
match err {
|
||||||
DifferentValuesIndicesLengths => E::from_kind_and_msg(
|
|
||||||
K::InvalidStructure,
|
|
||||||
"Lengths of values and column indices are not equal.",
|
|
||||||
),
|
|
||||||
InvalidOffsetArrayLength => E::from_kind_and_msg(
|
InvalidOffsetArrayLength => E::from_kind_and_msg(
|
||||||
K::InvalidStructure,
|
K::InvalidStructure,
|
||||||
"Length of row offset array is not equal to nrows + 1.",
|
"Length of row offset array is not equal to nrows + 1.",
|
||||||
|
|
|
@ -149,9 +149,9 @@ pub mod csr;
|
||||||
pub mod factorization;
|
pub mod factorization;
|
||||||
pub mod ops;
|
pub mod ops;
|
||||||
pub mod pattern;
|
pub mod pattern;
|
||||||
pub mod utils;
|
|
||||||
|
|
||||||
pub(crate) mod cs;
|
pub(crate) mod cs;
|
||||||
|
pub(crate) mod utils;
|
||||||
|
|
||||||
#[cfg(feature = "proptest-support")]
|
#[cfg(feature = "proptest-support")]
|
||||||
pub mod proptest;
|
pub mod proptest;
|
||||||
|
|
|
@ -1,6 +1,5 @@
|
||||||
//! Sparsity patterns for CSR and CSC matrices.
|
//! Sparsity patterns for CSR and CSC matrices.
|
||||||
use crate::cs::transpose_cs;
|
use crate::cs::transpose_cs;
|
||||||
use crate::utils::first_and_last_offsets_are_ok;
|
|
||||||
use crate::SparseFormatError;
|
use crate::SparseFormatError;
|
||||||
use std::error::Error;
|
use std::error::Error;
|
||||||
use std::fmt;
|
use std::fmt;
|
||||||
|
@ -133,9 +132,13 @@ impl SparsityPattern {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check that the first and last offsets conform to the specification
|
// Check that the first and last offsets conform to the specification
|
||||||
if !first_and_last_offsets_are_ok(&major_offsets, &minor_indices) {
|
{
|
||||||
|
let first_offset_ok = *major_offsets.first().unwrap() == 0;
|
||||||
|
let last_offset_ok = *major_offsets.last().unwrap() == minor_indices.len();
|
||||||
|
if !first_offset_ok || !last_offset_ok {
|
||||||
return Err(InvalidOffsetFirstLast);
|
return Err(InvalidOffsetFirstLast);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Test that each lane has strictly monotonically increasing minor indices, i.e.
|
// Test that each lane has strictly monotonically increasing minor indices, i.e.
|
||||||
// minor indices within a lane are sorted, unique. In addition, each minor index
|
// minor indices within a lane are sorted, unique. In addition, each minor index
|
||||||
|
@ -261,8 +264,6 @@ impl SparsityPattern {
|
||||||
pub enum SparsityPatternFormatError {
|
pub enum SparsityPatternFormatError {
|
||||||
/// Indicates an invalid number of offsets.
|
/// Indicates an invalid number of offsets.
|
||||||
///
|
///
|
||||||
/// Indicates that column indices and values have different lengths.
|
|
||||||
DifferentValuesIndicesLengths,
|
|
||||||
/// The number of offsets must be equal to (major_dim + 1).
|
/// The number of offsets must be equal to (major_dim + 1).
|
||||||
InvalidOffsetArrayLength,
|
InvalidOffsetArrayLength,
|
||||||
/// Indicates that the first or last entry in the offset array did not conform to
|
/// Indicates that the first or last entry in the offset array did not conform to
|
||||||
|
@ -291,8 +292,7 @@ impl From<SparsityPatternFormatError> for SparseFormatError {
|
||||||
use SparsityPatternFormatError::DuplicateEntry as PatternDuplicateEntry;
|
use SparsityPatternFormatError::DuplicateEntry as PatternDuplicateEntry;
|
||||||
use SparsityPatternFormatError::*;
|
use SparsityPatternFormatError::*;
|
||||||
match err {
|
match err {
|
||||||
DifferentValuesIndicesLengths
|
InvalidOffsetArrayLength
|
||||||
| InvalidOffsetArrayLength
|
|
||||||
| InvalidOffsetFirstLast
|
| InvalidOffsetFirstLast
|
||||||
| NonmonotonicOffsets
|
| NonmonotonicOffsets
|
||||||
| NonmonotonicMinorIndices => {
|
| NonmonotonicMinorIndices => {
|
||||||
|
@ -313,9 +313,6 @@ impl From<SparsityPatternFormatError> for SparseFormatError {
|
||||||
impl fmt::Display for SparsityPatternFormatError {
|
impl fmt::Display for SparsityPatternFormatError {
|
||||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
match self {
|
match self {
|
||||||
SparsityPatternFormatError::DifferentValuesIndicesLengths => {
|
|
||||||
write!(f, "Lengths of values and column indices are not equal.")
|
|
||||||
}
|
|
||||||
SparsityPatternFormatError::InvalidOffsetArrayLength => {
|
SparsityPatternFormatError::InvalidOffsetArrayLength => {
|
||||||
write!(f, "Length of offset array is not equal to (major_dim + 1).")
|
write!(f, "Length of offset array is not equal to (major_dim + 1).")
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,17 +1,5 @@
|
||||||
//! Helper functions for sparse matrix computations
|
//! Helper functions for sparse matrix computations
|
||||||
|
|
||||||
/// Check that the first and last offsets conform to the specification of a CSR matrix
|
|
||||||
#[inline]
|
|
||||||
#[must_use]
|
|
||||||
pub fn first_and_last_offsets_are_ok(
|
|
||||||
major_offsets: &Vec<usize>,
|
|
||||||
minor_indices: &Vec<usize>,
|
|
||||||
) -> bool {
|
|
||||||
let first_offset_ok = *major_offsets.first().unwrap() == 0;
|
|
||||||
let last_offset_ok = *major_offsets.last().unwrap() == minor_indices.len();
|
|
||||||
return first_offset_ok && last_offset_ok;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// permutes entries of in_slice according to permutation slice and puts them to out_slice
|
/// permutes entries of in_slice according to permutation slice and puts them to out_slice
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn apply_permutation<T: Clone>(out_slice: &mut [T], in_slice: &[T], permutation: &[usize]) {
|
pub fn apply_permutation<T: Clone>(out_slice: &mut [T], in_slice: &[T], permutation: &[usize]) {
|
||||||
|
|
Loading…
Reference in New Issue