Apply permutation

This commit is contained in:
Anton 2021-10-08 00:36:40 +02:00
parent a2a55cddca
commit 469765a4e5
8 changed files with 71 additions and 31 deletions

View File

@ -14,6 +14,7 @@ use crate::coo::CooMatrix;
use crate::cs; use crate::cs;
use crate::csc::CscMatrix; use crate::csc::CscMatrix;
use crate::csr::CsrMatrix; use crate::csr::CsrMatrix;
use crate::utils::apply_permutation;
/// Converts a dense matrix to [`CooMatrix`]. /// Converts a dense matrix to [`CooMatrix`].
pub fn convert_dense_coo<T, R, C, S>(dense: &Matrix<T, R, C, S>) -> CooMatrix<T> pub fn convert_dense_coo<T, R, C, S>(dense: &Matrix<T, R, C, S>) -> CooMatrix<T>
@ -390,15 +391,6 @@ fn sort_lane<T: Clone>(
apply_permutation(values_result, values, permutation); apply_permutation(values_result, values, permutation);
} }
// TODO: Move this into `utils` or something?
fn apply_permutation<T: Clone>(out_slice: &mut [T], in_slice: &[T], permutation: &[usize]) {
assert_eq!(out_slice.len(), in_slice.len());
assert_eq!(out_slice.len(), permutation.len());
for (out_element, old_pos) in out_slice.iter_mut().zip(permutation) {
*out_element = in_slice[*old_pos].clone();
}
}
/// Given *sorted* indices and corresponding scalar values, combines duplicates with the given /// Given *sorted* indices and corresponding scalar values, combines duplicates with the given
/// associative combiner and calls the provided produce methods with combined indices and values. /// associative combiner and calls the provided produce methods with combined indices and values.
fn combine_duplicates<T: Clone>( fn combine_duplicates<T: Clone>(

View File

@ -535,6 +535,10 @@ fn pattern_format_error_to_csc_error(err: SparsityPatternFormatError) -> SparseF
use SparsityPatternFormatError::*; use SparsityPatternFormatError::*;
match err { match err {
DifferentValuesIndicesLengths => E::from_kind_and_msg(
K::InvalidStructure,
"Lengths of values and column indices are not equal.",
),
InvalidOffsetArrayLength => E::from_kind_and_msg( InvalidOffsetArrayLength => E::from_kind_and_msg(
K::InvalidStructure, K::InvalidStructure,
"Length of col offset array is not equal to ncols + 1.", "Length of col offset array is not equal to ncols + 1.",

View File

@ -5,11 +5,13 @@
use crate::cs::{CsLane, CsLaneIter, CsLaneIterMut, CsLaneMut, CsMatrix}; use crate::cs::{CsLane, CsLaneIter, CsLaneIterMut, CsLaneMut, CsMatrix};
use crate::csc::CscMatrix; use crate::csc::CscMatrix;
use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter}; use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter};
use crate::utils::{apply_permutation, first_and_last_offsets_are_ok};
use crate::{SparseEntry, SparseEntryMut, SparseFormatError, SparseFormatErrorKind}; use crate::{SparseEntry, SparseEntryMut, SparseFormatError, SparseFormatErrorKind};
use nalgebra::Scalar; use nalgebra::Scalar;
use num_traits::One; use num_traits::One;
use num_traits::Zero;
use std::slice::{Iter, IterMut}; use std::slice::{Iter, IterMut};
/// A CSR representation of a sparse matrix. /// A CSR representation of a sparse matrix.
@ -184,20 +186,40 @@ impl<T> CsrMatrix<T> {
row_offsets: Vec<usize>, row_offsets: Vec<usize>,
col_indices: Vec<usize>, col_indices: Vec<usize>,
values: Vec<T>, values: Vec<T>,
) -> Result<Self, SparseFormatError> { ) -> Result<Self, SparseFormatError>
use nalgebra::base::helper; where
T: Scalar + Zero,
{
use SparsityPatternFormatError::*; use SparsityPatternFormatError::*;
if helper::first_and_last_offsets_are_ok(&row_offsets, &col_indices) {
let mut sorted_col_indices = col_indices.clone(); let mut p: Vec<usize> = (0..col_indices.len()).collect();
if col_indices.len() != values.len() {
return (Err(DifferentValuesIndicesLengths)).map_err(pattern_format_error_to_csr_error);
}
if first_and_last_offsets_are_ok(&row_offsets, &col_indices) {
for (index, &offset) in row_offsets[0..row_offsets.len() - 1].iter().enumerate() { for (index, &offset) in row_offsets[0..row_offsets.len() - 1].iter().enumerate() {
sorted_col_indices[offset..row_offsets[index + 1]].sort_unstable(); p[offset..row_offsets[index + 1]].sort_by(|a, b| {
let x = &col_indices[*a];
let y = &col_indices[*b];
x.partial_cmp(y).unwrap()
});
} }
// permute indices
let sorted_col_indices: Vec<usize> = p.iter().map(|i| col_indices[*i]).collect();
// permute values
let mut output: Vec<T> = vec![T::zero(); p.len()];
apply_permutation(&mut output[..p.len()], &values[..p.len()], &p[..p.len()]);
return Self::try_from_csr_data( return Self::try_from_csr_data(
num_rows, num_rows,
num_cols, num_cols,
row_offsets, row_offsets,
sorted_col_indices, sorted_col_indices,
values, output,
); );
} }
return (Err(InvalidOffsetFirstLast)).map_err(pattern_format_error_to_csr_error); return (Err(InvalidOffsetFirstLast)).map_err(pattern_format_error_to_csr_error);
@ -568,6 +590,10 @@ fn pattern_format_error_to_csr_error(err: SparsityPatternFormatError) -> SparseF
use SparsityPatternFormatError::*; use SparsityPatternFormatError::*;
match err { match err {
DifferentValuesIndicesLengths => E::from_kind_and_msg(
K::InvalidStructure,
"Lengths of values and column indices are not equal.",
),
InvalidOffsetArrayLength => E::from_kind_and_msg( InvalidOffsetArrayLength => E::from_kind_and_msg(
K::InvalidStructure, K::InvalidStructure,
"Length of row offset array is not equal to nrows + 1.", "Length of row offset array is not equal to nrows + 1.",

View File

@ -149,6 +149,7 @@ pub mod csr;
pub mod factorization; pub mod factorization;
pub mod ops; pub mod ops;
pub mod pattern; pub mod pattern;
pub mod utils;
pub(crate) mod cs; pub(crate) mod cs;

View File

@ -1,5 +1,6 @@
//! Sparsity patterns for CSR and CSC matrices. //! Sparsity patterns for CSR and CSC matrices.
use crate::cs::transpose_cs; use crate::cs::transpose_cs;
use crate::utils::first_and_last_offsets_are_ok;
use crate::SparseFormatError; use crate::SparseFormatError;
use std::error::Error; use std::error::Error;
use std::fmt; use std::fmt;
@ -125,7 +126,6 @@ impl SparsityPattern {
major_offsets: Vec<usize>, major_offsets: Vec<usize>,
minor_indices: Vec<usize>, minor_indices: Vec<usize>,
) -> Result<Self, SparsityPatternFormatError> { ) -> Result<Self, SparsityPatternFormatError> {
use nalgebra::base::helper;
use SparsityPatternFormatError::*; use SparsityPatternFormatError::*;
if major_offsets.len() != major_dim + 1 { if major_offsets.len() != major_dim + 1 {
@ -133,7 +133,7 @@ impl SparsityPattern {
} }
// Check that the first and last offsets conform to the specification // Check that the first and last offsets conform to the specification
if !helper::first_and_last_offsets_are_ok(&major_offsets, &minor_indices) { if !first_and_last_offsets_are_ok(&major_offsets, &minor_indices) {
return Err(InvalidOffsetFirstLast); return Err(InvalidOffsetFirstLast);
} }
@ -261,6 +261,8 @@ impl SparsityPattern {
pub enum SparsityPatternFormatError { pub enum SparsityPatternFormatError {
/// Indicates an invalid number of offsets. /// Indicates an invalid number of offsets.
/// ///
/// Indicates that column indices and values have different lengths.
DifferentValuesIndicesLengths,
/// The number of offsets must be equal to (major_dim + 1). /// The number of offsets must be equal to (major_dim + 1).
InvalidOffsetArrayLength, InvalidOffsetArrayLength,
/// Indicates that the first or last entry in the offset array did not conform to /// Indicates that the first or last entry in the offset array did not conform to
@ -289,7 +291,8 @@ impl From<SparsityPatternFormatError> for SparseFormatError {
use SparsityPatternFormatError::DuplicateEntry as PatternDuplicateEntry; use SparsityPatternFormatError::DuplicateEntry as PatternDuplicateEntry;
use SparsityPatternFormatError::*; use SparsityPatternFormatError::*;
match err { match err {
InvalidOffsetArrayLength DifferentValuesIndicesLengths
| InvalidOffsetArrayLength
| InvalidOffsetFirstLast | InvalidOffsetFirstLast
| NonmonotonicOffsets | NonmonotonicOffsets
| NonmonotonicMinorIndices => { | NonmonotonicMinorIndices => {
@ -310,6 +313,9 @@ impl From<SparsityPatternFormatError> for SparseFormatError {
impl fmt::Display for SparsityPatternFormatError { impl fmt::Display for SparsityPatternFormatError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self { match self {
SparsityPatternFormatError::DifferentValuesIndicesLengths => {
write!(f, "Lengths of values and column indices are not equal.")
}
SparsityPatternFormatError::InvalidOffsetArrayLength => { SparsityPatternFormatError::InvalidOffsetArrayLength => {
write!(f, "Length of offset array is not equal to (major_dim + 1).") write!(f, "Length of offset array is not equal to (major_dim + 1).")
} }

View File

@ -0,0 +1,23 @@
//! Helper functions for sparse matrix computations
/// Check that the first and last offsets conform to the specification of a CSR matrix
#[inline]
#[must_use]
pub fn first_and_last_offsets_are_ok(
major_offsets: &Vec<usize>,
minor_indices: &Vec<usize>,
) -> bool {
let first_offset_ok = *major_offsets.first().unwrap() == 0;
let last_offset_ok = *major_offsets.last().unwrap() == minor_indices.len();
return first_offset_ok && last_offset_ok;
}
/// permutes entries of in_slice according to permutation slice and puts them to out_slice
#[inline]
pub fn apply_permutation<T: Clone>(out_slice: &mut [T], in_slice: &[T], permutation: &[usize]) {
assert_eq!(out_slice.len(), in_slice.len());
assert_eq!(out_slice.len(), permutation.len());
for (out_element, old_pos) in out_slice.iter_mut().zip(permutation) {
*out_element = in_slice[*old_pos].clone();
}
}

View File

@ -178,7 +178,7 @@ fn csr_matrix_valid_data_unsorted_column_indices() {
4, 4,
vec![0, 1, 2, 5], vec![0, 1, 2, 5],
vec![1, 3, 2, 3, 0], vec![1, 3, 2, 3, 0],
vec![5, 4, 1, 1, 4], vec![5, 4, 1, 4, 1],
) )
.unwrap(); .unwrap();

View File

@ -29,15 +29,3 @@ where
use std::iter; use std::iter;
iter::repeat(()).map(|_| g.gen()).find(f).unwrap() iter::repeat(()).map(|_| g.gen()).find(f).unwrap()
} }
/// Check that the first and last offsets conform to the specification of a CSR matrix
#[inline]
#[must_use]
pub fn first_and_last_offsets_are_ok(
major_offsets: &Vec<usize>,
minor_indices: &Vec<usize>,
) -> bool {
let first_offset_ok = *major_offsets.first().unwrap() == 0;
let last_offset_ok = *major_offsets.last().unwrap() == minor_indices.len();
return first_offset_ok && last_offset_ok;
}