Check first and last offsets before sorting column indices

This commit is contained in:
Anton 2021-10-04 20:17:27 +02:00
parent 9e85c9e2b6
commit a2a55cddca
3 changed files with 39 additions and 24 deletions

View File

@ -170,7 +170,14 @@ impl<T> CsrMatrix<T> {
Self::try_from_pattern_and_values(pattern, values) Self::try_from_pattern_and_values(pattern, values)
} }
/// Try to construct a CSR matrix from raw CSR data with unsorted columns. /// Try to construct a CSR matrix from raw CSR data with unsorted column indices.
///
/// It is assumed that each row contains unique column indices that are in
/// bounds with respect to the number of columns in the matrix. If this is not the case,
/// an error is returned to indicate the failure.
///
/// An error is returned if the data given does not conform to the CSR storage format.
/// See the documentation for [CsrMatrix](struct.CsrMatrix.html) for more information.
pub fn try_from_unsorted_csr_data( pub fn try_from_unsorted_csr_data(
num_rows: usize, num_rows: usize,
num_cols: usize, num_cols: usize,
@ -178,14 +185,22 @@ impl<T> CsrMatrix<T> {
col_indices: Vec<usize>, col_indices: Vec<usize>,
values: Vec<T>, values: Vec<T>,
) -> Result<Self, SparseFormatError> { ) -> Result<Self, SparseFormatError> {
let sorted_num_cols: Vec<usize> = row_offsets[0..row_offsets.len() - 1] use nalgebra::base::helper;
.iter() use SparsityPatternFormatError::*;
.enumerate() if helper::first_and_last_offsets_are_ok(&row_offsets, &col_indices) {
.flat_map(|(index, &offset)| { let mut sorted_col_indices = col_indices.clone();
Self::sorted(col_indices[offset..row_offsets[index + 1]].to_vec()) for (index, &offset) in row_offsets[0..row_offsets.len() - 1].iter().enumerate() {
}) sorted_col_indices[offset..row_offsets[index + 1]].sort_unstable();
.collect(); }
return Self::try_from_csr_data(num_rows, num_cols, row_offsets, sorted_num_cols, values); return Self::try_from_csr_data(
num_rows,
num_cols,
row_offsets,
sorted_col_indices,
values,
);
}
return (Err(InvalidOffsetFirstLast)).map_err(pattern_format_error_to_csr_error);
} }
/// Try to construct a CSR matrix from a sparsity pattern and associated non-zero values. /// Try to construct a CSR matrix from a sparsity pattern and associated non-zero values.
@ -208,15 +223,6 @@ impl<T> CsrMatrix<T> {
} }
} }
/// Return sorted vector.
#[inline]
#[must_use]
pub fn sorted(row_offsets: Vec<usize>) -> Vec<usize> {
let mut sorted = row_offsets.clone();
sorted.sort();
return sorted;
}
/// The number of rows in the matrix. /// The number of rows in the matrix.
#[inline] #[inline]
#[must_use] #[must_use]

View File

@ -125,6 +125,7 @@ impl SparsityPattern {
major_offsets: Vec<usize>, major_offsets: Vec<usize>,
minor_indices: Vec<usize>, minor_indices: Vec<usize>,
) -> Result<Self, SparsityPatternFormatError> { ) -> Result<Self, SparsityPatternFormatError> {
use nalgebra::base::helper;
use SparsityPatternFormatError::*; use SparsityPatternFormatError::*;
if major_offsets.len() != major_dim + 1 { if major_offsets.len() != major_dim + 1 {
@ -132,12 +133,8 @@ impl SparsityPattern {
} }
// Check that the first and last offsets conform to the specification // Check that the first and last offsets conform to the specification
{ if !helper::first_and_last_offsets_are_ok(&major_offsets, &minor_indices) {
let first_offset_ok = *major_offsets.first().unwrap() == 0; return Err(InvalidOffsetFirstLast);
let last_offset_ok = *major_offsets.last().unwrap() == minor_indices.len();
if !first_offset_ok || !last_offset_ok {
return Err(InvalidOffsetFirstLast);
}
} }
// Test that each lane has strictly monotonically increasing minor indices, i.e. // Test that each lane has strictly monotonically increasing minor indices, i.e.

View File

@ -29,3 +29,15 @@ where
use std::iter; use std::iter;
iter::repeat(()).map(|_| g.gen()).find(f).unwrap() iter::repeat(()).map(|_| g.gen()).find(f).unwrap()
} }
/// Check that the first and last offsets conform to the specification of a CSR matrix
#[inline]
#[must_use]
pub fn first_and_last_offsets_are_ok(
major_offsets: &Vec<usize>,
minor_indices: &Vec<usize>,
) -> bool {
let first_offset_ok = *major_offsets.first().unwrap() == 0;
let last_offset_ok = *major_offsets.last().unwrap() == minor_indices.len();
return first_offset_ok && last_offset_ok;
}