From 469765a4e59915ac4db5695140177963dd6fb403 Mon Sep 17 00:00:00 2001 From: Anton Date: Fri, 8 Oct 2021 00:36:40 +0200 Subject: [PATCH] Apply permutation --- nalgebra-sparse/src/convert/serial.rs | 10 +------ nalgebra-sparse/src/csc.rs | 4 +++ nalgebra-sparse/src/csr.rs | 38 +++++++++++++++++++++---- nalgebra-sparse/src/lib.rs | 1 + nalgebra-sparse/src/pattern.rs | 12 ++++++-- nalgebra-sparse/src/utils.rs | 23 +++++++++++++++ nalgebra-sparse/tests/unit_tests/csr.rs | 2 +- src/base/helper.rs | 12 -------- 8 files changed, 71 insertions(+), 31 deletions(-) create mode 100644 nalgebra-sparse/src/utils.rs diff --git a/nalgebra-sparse/src/convert/serial.rs b/nalgebra-sparse/src/convert/serial.rs index ecbe1dab..219a6bf7 100644 --- a/nalgebra-sparse/src/convert/serial.rs +++ b/nalgebra-sparse/src/convert/serial.rs @@ -14,6 +14,7 @@ use crate::coo::CooMatrix; use crate::cs; use crate::csc::CscMatrix; use crate::csr::CsrMatrix; +use crate::utils::apply_permutation; /// Converts a dense matrix to [`CooMatrix`]. pub fn convert_dense_coo(dense: &Matrix) -> CooMatrix @@ -390,15 +391,6 @@ fn sort_lane( apply_permutation(values_result, values, permutation); } -// TODO: Move this into `utils` or something? -fn apply_permutation(out_slice: &mut [T], in_slice: &[T], permutation: &[usize]) { - assert_eq!(out_slice.len(), in_slice.len()); - assert_eq!(out_slice.len(), permutation.len()); - for (out_element, old_pos) in out_slice.iter_mut().zip(permutation) { - *out_element = in_slice[*old_pos].clone(); - } -} - /// Given *sorted* indices and corresponding scalar values, combines duplicates with the given /// associative combiner and calls the provided produce methods with combined indices and values. fn combine_duplicates( diff --git a/nalgebra-sparse/src/csc.rs b/nalgebra-sparse/src/csc.rs index 607cc0cf..b770bbf3 100644 --- a/nalgebra-sparse/src/csc.rs +++ b/nalgebra-sparse/src/csc.rs @@ -535,6 +535,10 @@ fn pattern_format_error_to_csc_error(err: SparsityPatternFormatError) -> SparseF use SparsityPatternFormatError::*; match err { + DifferentValuesIndicesLengths => E::from_kind_and_msg( + K::InvalidStructure, + "Lengths of values and column indices are not equal.", + ), InvalidOffsetArrayLength => E::from_kind_and_msg( K::InvalidStructure, "Length of col offset array is not equal to ncols + 1.", diff --git a/nalgebra-sparse/src/csr.rs b/nalgebra-sparse/src/csr.rs index 101a5fb3..c0bef4d3 100644 --- a/nalgebra-sparse/src/csr.rs +++ b/nalgebra-sparse/src/csr.rs @@ -5,11 +5,13 @@ use crate::cs::{CsLane, CsLaneIter, CsLaneIterMut, CsLaneMut, CsMatrix}; use crate::csc::CscMatrix; use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter}; +use crate::utils::{apply_permutation, first_and_last_offsets_are_ok}; use crate::{SparseEntry, SparseEntryMut, SparseFormatError, SparseFormatErrorKind}; use nalgebra::Scalar; use num_traits::One; +use num_traits::Zero; use std::slice::{Iter, IterMut}; /// A CSR representation of a sparse matrix. @@ -184,20 +186,40 @@ impl CsrMatrix { row_offsets: Vec, col_indices: Vec, values: Vec, - ) -> Result { - use nalgebra::base::helper; + ) -> Result + where + T: Scalar + Zero, + { use SparsityPatternFormatError::*; - if helper::first_and_last_offsets_are_ok(&row_offsets, &col_indices) { - let mut sorted_col_indices = col_indices.clone(); + + let mut p: Vec = (0..col_indices.len()).collect(); + + if col_indices.len() != values.len() { + return (Err(DifferentValuesIndicesLengths)).map_err(pattern_format_error_to_csr_error); + } + + if first_and_last_offsets_are_ok(&row_offsets, &col_indices) { for (index, &offset) in row_offsets[0..row_offsets.len() - 1].iter().enumerate() { - sorted_col_indices[offset..row_offsets[index + 1]].sort_unstable(); + p[offset..row_offsets[index + 1]].sort_by(|a, b| { + let x = &col_indices[*a]; + let y = &col_indices[*b]; + x.partial_cmp(y).unwrap() + }); } + + // permute indices + let sorted_col_indices: Vec = p.iter().map(|i| col_indices[*i]).collect(); + + // permute values + let mut output: Vec = vec![T::zero(); p.len()]; + apply_permutation(&mut output[..p.len()], &values[..p.len()], &p[..p.len()]); + return Self::try_from_csr_data( num_rows, num_cols, row_offsets, sorted_col_indices, - values, + output, ); } return (Err(InvalidOffsetFirstLast)).map_err(pattern_format_error_to_csr_error); @@ -568,6 +590,10 @@ fn pattern_format_error_to_csr_error(err: SparsityPatternFormatError) -> SparseF use SparsityPatternFormatError::*; match err { + DifferentValuesIndicesLengths => E::from_kind_and_msg( + K::InvalidStructure, + "Lengths of values and column indices are not equal.", + ), InvalidOffsetArrayLength => E::from_kind_and_msg( K::InvalidStructure, "Length of row offset array is not equal to nrows + 1.", diff --git a/nalgebra-sparse/src/lib.rs b/nalgebra-sparse/src/lib.rs index bf845757..607a1abf 100644 --- a/nalgebra-sparse/src/lib.rs +++ b/nalgebra-sparse/src/lib.rs @@ -149,6 +149,7 @@ pub mod csr; pub mod factorization; pub mod ops; pub mod pattern; +pub mod utils; pub(crate) mod cs; diff --git a/nalgebra-sparse/src/pattern.rs b/nalgebra-sparse/src/pattern.rs index 9fdcad38..3e30ee9b 100644 --- a/nalgebra-sparse/src/pattern.rs +++ b/nalgebra-sparse/src/pattern.rs @@ -1,5 +1,6 @@ //! Sparsity patterns for CSR and CSC matrices. use crate::cs::transpose_cs; +use crate::utils::first_and_last_offsets_are_ok; use crate::SparseFormatError; use std::error::Error; use std::fmt; @@ -125,7 +126,6 @@ impl SparsityPattern { major_offsets: Vec, minor_indices: Vec, ) -> Result { - use nalgebra::base::helper; use SparsityPatternFormatError::*; if major_offsets.len() != major_dim + 1 { @@ -133,7 +133,7 @@ impl SparsityPattern { } // Check that the first and last offsets conform to the specification - if !helper::first_and_last_offsets_are_ok(&major_offsets, &minor_indices) { + if !first_and_last_offsets_are_ok(&major_offsets, &minor_indices) { return Err(InvalidOffsetFirstLast); } @@ -261,6 +261,8 @@ impl SparsityPattern { pub enum SparsityPatternFormatError { /// Indicates an invalid number of offsets. /// + /// Indicates that column indices and values have different lengths. + DifferentValuesIndicesLengths, /// The number of offsets must be equal to (major_dim + 1). InvalidOffsetArrayLength, /// Indicates that the first or last entry in the offset array did not conform to @@ -289,7 +291,8 @@ impl From for SparseFormatError { use SparsityPatternFormatError::DuplicateEntry as PatternDuplicateEntry; use SparsityPatternFormatError::*; match err { - InvalidOffsetArrayLength + DifferentValuesIndicesLengths + | InvalidOffsetArrayLength | InvalidOffsetFirstLast | NonmonotonicOffsets | NonmonotonicMinorIndices => { @@ -310,6 +313,9 @@ impl From for SparseFormatError { impl fmt::Display for SparsityPatternFormatError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { + SparsityPatternFormatError::DifferentValuesIndicesLengths => { + write!(f, "Lengths of values and column indices are not equal.") + } SparsityPatternFormatError::InvalidOffsetArrayLength => { write!(f, "Length of offset array is not equal to (major_dim + 1).") } diff --git a/nalgebra-sparse/src/utils.rs b/nalgebra-sparse/src/utils.rs new file mode 100644 index 00000000..411e6e0a --- /dev/null +++ b/nalgebra-sparse/src/utils.rs @@ -0,0 +1,23 @@ +//! Helper functions for sparse matrix computations + +/// Check that the first and last offsets conform to the specification of a CSR matrix +#[inline] +#[must_use] +pub fn first_and_last_offsets_are_ok( + major_offsets: &Vec, + minor_indices: &Vec, +) -> bool { + let first_offset_ok = *major_offsets.first().unwrap() == 0; + let last_offset_ok = *major_offsets.last().unwrap() == minor_indices.len(); + return first_offset_ok && last_offset_ok; +} + +/// permutes entries of in_slice according to permutation slice and puts them to out_slice +#[inline] +pub fn apply_permutation(out_slice: &mut [T], in_slice: &[T], permutation: &[usize]) { + assert_eq!(out_slice.len(), in_slice.len()); + assert_eq!(out_slice.len(), permutation.len()); + for (out_element, old_pos) in out_slice.iter_mut().zip(permutation) { + *out_element = in_slice[*old_pos].clone(); + } +} diff --git a/nalgebra-sparse/tests/unit_tests/csr.rs b/nalgebra-sparse/tests/unit_tests/csr.rs index 73d4dd27..38c2b344 100644 --- a/nalgebra-sparse/tests/unit_tests/csr.rs +++ b/nalgebra-sparse/tests/unit_tests/csr.rs @@ -178,7 +178,7 @@ fn csr_matrix_valid_data_unsorted_column_indices() { 4, vec![0, 1, 2, 5], vec![1, 3, 2, 3, 0], - vec![5, 4, 1, 1, 4], + vec![5, 4, 1, 4, 1], ) .unwrap(); diff --git a/src/base/helper.rs b/src/base/helper.rs index f596c955..00bd462c 100644 --- a/src/base/helper.rs +++ b/src/base/helper.rs @@ -29,15 +29,3 @@ where use std::iter; iter::repeat(()).map(|_| g.gen()).find(f).unwrap() } - -/// Check that the first and last offsets conform to the specification of a CSR matrix -#[inline] -#[must_use] -pub fn first_and_last_offsets_are_ok( - major_offsets: &Vec, - minor_indices: &Vec, -) -> bool { - let first_offset_ok = *major_offsets.first().unwrap() == 0; - let last_offset_ok = *major_offsets.last().unwrap() == minor_indices.len(); - return first_offset_ok && last_offset_ok; -}