From a2a55cddcaa2c2f0d2bb7a82fc394e1ebc5109c7 Mon Sep 17 00:00:00 2001 From: Anton Date: Mon, 4 Oct 2021 20:17:27 +0200 Subject: [PATCH] Check first and last offsets before sorting column indices --- nalgebra-sparse/src/csr.rs | 42 +++++++++++++++++++--------------- nalgebra-sparse/src/pattern.rs | 9 +++----- src/base/helper.rs | 12 ++++++++++ 3 files changed, 39 insertions(+), 24 deletions(-) diff --git a/nalgebra-sparse/src/csr.rs b/nalgebra-sparse/src/csr.rs index fd39fc75..101a5fb3 100644 --- a/nalgebra-sparse/src/csr.rs +++ b/nalgebra-sparse/src/csr.rs @@ -170,7 +170,14 @@ impl CsrMatrix { Self::try_from_pattern_and_values(pattern, values) } - /// Try to construct a CSR matrix from raw CSR data with unsorted columns. + /// Try to construct a CSR matrix from raw CSR data with unsorted column indices. + /// + /// It is assumed that each row contains unique column indices that are in + /// bounds with respect to the number of columns in the matrix. If this is not the case, + /// an error is returned to indicate the failure. + /// + /// An error is returned if the data given does not conform to the CSR storage format. + /// See the documentation for [CsrMatrix](struct.CsrMatrix.html) for more information. pub fn try_from_unsorted_csr_data( num_rows: usize, num_cols: usize, @@ -178,14 +185,22 @@ impl CsrMatrix { col_indices: Vec, values: Vec, ) -> Result { - let sorted_num_cols: Vec = row_offsets[0..row_offsets.len() - 1] - .iter() - .enumerate() - .flat_map(|(index, &offset)| { - Self::sorted(col_indices[offset..row_offsets[index + 1]].to_vec()) - }) - .collect(); - return Self::try_from_csr_data(num_rows, num_cols, row_offsets, sorted_num_cols, values); + use nalgebra::base::helper; + use SparsityPatternFormatError::*; + if helper::first_and_last_offsets_are_ok(&row_offsets, &col_indices) { + let mut sorted_col_indices = col_indices.clone(); + for (index, &offset) in row_offsets[0..row_offsets.len() - 1].iter().enumerate() { + sorted_col_indices[offset..row_offsets[index + 1]].sort_unstable(); + } + return Self::try_from_csr_data( + num_rows, + num_cols, + row_offsets, + sorted_col_indices, + values, + ); + } + return (Err(InvalidOffsetFirstLast)).map_err(pattern_format_error_to_csr_error); } /// Try to construct a CSR matrix from a sparsity pattern and associated non-zero values. @@ -208,15 +223,6 @@ impl CsrMatrix { } } - /// Return sorted vector. - #[inline] - #[must_use] - pub fn sorted(row_offsets: Vec) -> Vec { - let mut sorted = row_offsets.clone(); - sorted.sort(); - return sorted; - } - /// The number of rows in the matrix. #[inline] #[must_use] diff --git a/nalgebra-sparse/src/pattern.rs b/nalgebra-sparse/src/pattern.rs index 85f6bc1a..9fdcad38 100644 --- a/nalgebra-sparse/src/pattern.rs +++ b/nalgebra-sparse/src/pattern.rs @@ -125,6 +125,7 @@ impl SparsityPattern { major_offsets: Vec, minor_indices: Vec, ) -> Result { + use nalgebra::base::helper; use SparsityPatternFormatError::*; if major_offsets.len() != major_dim + 1 { @@ -132,12 +133,8 @@ impl SparsityPattern { } // Check that the first and last offsets conform to the specification - { - let first_offset_ok = *major_offsets.first().unwrap() == 0; - let last_offset_ok = *major_offsets.last().unwrap() == minor_indices.len(); - if !first_offset_ok || !last_offset_ok { - return Err(InvalidOffsetFirstLast); - } + if !helper::first_and_last_offsets_are_ok(&major_offsets, &minor_indices) { + return Err(InvalidOffsetFirstLast); } // Test that each lane has strictly monotonically increasing minor indices, i.e. diff --git a/src/base/helper.rs b/src/base/helper.rs index 00bd462c..f596c955 100644 --- a/src/base/helper.rs +++ b/src/base/helper.rs @@ -29,3 +29,15 @@ where use std::iter; iter::repeat(()).map(|_| g.gen()).find(f).unwrap() } + +/// Check that the first and last offsets conform to the specification of a CSR matrix +#[inline] +#[must_use] +pub fn first_and_last_offsets_are_ok( + major_offsets: &Vec, + minor_indices: &Vec, +) -> bool { + let first_offset_ok = *major_offsets.first().unwrap() == 0; + let last_offset_ok = *major_offsets.last().unwrap() == minor_indices.len(); + return first_offset_ok && last_offset_ok; +}