forked from M-Labs/nalgebra
Apply permutation
This commit is contained in:
parent
a2a55cddca
commit
469765a4e5
@ -14,6 +14,7 @@ use crate::coo::CooMatrix;
|
||||
use crate::cs;
|
||||
use crate::csc::CscMatrix;
|
||||
use crate::csr::CsrMatrix;
|
||||
use crate::utils::apply_permutation;
|
||||
|
||||
/// Converts a dense matrix to [`CooMatrix`].
|
||||
pub fn convert_dense_coo<T, R, C, S>(dense: &Matrix<T, R, C, S>) -> CooMatrix<T>
|
||||
@ -390,15 +391,6 @@ fn sort_lane<T: Clone>(
|
||||
apply_permutation(values_result, values, permutation);
|
||||
}
|
||||
|
||||
// TODO: Move this into `utils` or something?
|
||||
fn apply_permutation<T: Clone>(out_slice: &mut [T], in_slice: &[T], permutation: &[usize]) {
|
||||
assert_eq!(out_slice.len(), in_slice.len());
|
||||
assert_eq!(out_slice.len(), permutation.len());
|
||||
for (out_element, old_pos) in out_slice.iter_mut().zip(permutation) {
|
||||
*out_element = in_slice[*old_pos].clone();
|
||||
}
|
||||
}
|
||||
|
||||
/// Given *sorted* indices and corresponding scalar values, combines duplicates with the given
|
||||
/// associative combiner and calls the provided produce methods with combined indices and values.
|
||||
fn combine_duplicates<T: Clone>(
|
||||
|
@ -535,6 +535,10 @@ fn pattern_format_error_to_csc_error(err: SparsityPatternFormatError) -> SparseF
|
||||
use SparsityPatternFormatError::*;
|
||||
|
||||
match err {
|
||||
DifferentValuesIndicesLengths => E::from_kind_and_msg(
|
||||
K::InvalidStructure,
|
||||
"Lengths of values and column indices are not equal.",
|
||||
),
|
||||
InvalidOffsetArrayLength => E::from_kind_and_msg(
|
||||
K::InvalidStructure,
|
||||
"Length of col offset array is not equal to ncols + 1.",
|
||||
|
@ -5,11 +5,13 @@
|
||||
use crate::cs::{CsLane, CsLaneIter, CsLaneIterMut, CsLaneMut, CsMatrix};
|
||||
use crate::csc::CscMatrix;
|
||||
use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter};
|
||||
use crate::utils::{apply_permutation, first_and_last_offsets_are_ok};
|
||||
use crate::{SparseEntry, SparseEntryMut, SparseFormatError, SparseFormatErrorKind};
|
||||
|
||||
use nalgebra::Scalar;
|
||||
use num_traits::One;
|
||||
|
||||
use num_traits::Zero;
|
||||
use std::slice::{Iter, IterMut};
|
||||
|
||||
/// A CSR representation of a sparse matrix.
|
||||
@ -184,20 +186,40 @@ impl<T> CsrMatrix<T> {
|
||||
row_offsets: Vec<usize>,
|
||||
col_indices: Vec<usize>,
|
||||
values: Vec<T>,
|
||||
) -> Result<Self, SparseFormatError> {
|
||||
use nalgebra::base::helper;
|
||||
) -> Result<Self, SparseFormatError>
|
||||
where
|
||||
T: Scalar + Zero,
|
||||
{
|
||||
use SparsityPatternFormatError::*;
|
||||
if helper::first_and_last_offsets_are_ok(&row_offsets, &col_indices) {
|
||||
let mut sorted_col_indices = col_indices.clone();
|
||||
for (index, &offset) in row_offsets[0..row_offsets.len() - 1].iter().enumerate() {
|
||||
sorted_col_indices[offset..row_offsets[index + 1]].sort_unstable();
|
||||
|
||||
let mut p: Vec<usize> = (0..col_indices.len()).collect();
|
||||
|
||||
if col_indices.len() != values.len() {
|
||||
return (Err(DifferentValuesIndicesLengths)).map_err(pattern_format_error_to_csr_error);
|
||||
}
|
||||
|
||||
if first_and_last_offsets_are_ok(&row_offsets, &col_indices) {
|
||||
for (index, &offset) in row_offsets[0..row_offsets.len() - 1].iter().enumerate() {
|
||||
p[offset..row_offsets[index + 1]].sort_by(|a, b| {
|
||||
let x = &col_indices[*a];
|
||||
let y = &col_indices[*b];
|
||||
x.partial_cmp(y).unwrap()
|
||||
});
|
||||
}
|
||||
|
||||
// permute indices
|
||||
let sorted_col_indices: Vec<usize> = p.iter().map(|i| col_indices[*i]).collect();
|
||||
|
||||
// permute values
|
||||
let mut output: Vec<T> = vec![T::zero(); p.len()];
|
||||
apply_permutation(&mut output[..p.len()], &values[..p.len()], &p[..p.len()]);
|
||||
|
||||
return Self::try_from_csr_data(
|
||||
num_rows,
|
||||
num_cols,
|
||||
row_offsets,
|
||||
sorted_col_indices,
|
||||
values,
|
||||
output,
|
||||
);
|
||||
}
|
||||
return (Err(InvalidOffsetFirstLast)).map_err(pattern_format_error_to_csr_error);
|
||||
@ -568,6 +590,10 @@ fn pattern_format_error_to_csr_error(err: SparsityPatternFormatError) -> SparseF
|
||||
use SparsityPatternFormatError::*;
|
||||
|
||||
match err {
|
||||
DifferentValuesIndicesLengths => E::from_kind_and_msg(
|
||||
K::InvalidStructure,
|
||||
"Lengths of values and column indices are not equal.",
|
||||
),
|
||||
InvalidOffsetArrayLength => E::from_kind_and_msg(
|
||||
K::InvalidStructure,
|
||||
"Length of row offset array is not equal to nrows + 1.",
|
||||
|
@ -149,6 +149,7 @@ pub mod csr;
|
||||
pub mod factorization;
|
||||
pub mod ops;
|
||||
pub mod pattern;
|
||||
pub mod utils;
|
||||
|
||||
pub(crate) mod cs;
|
||||
|
||||
|
@ -1,5 +1,6 @@
|
||||
//! Sparsity patterns for CSR and CSC matrices.
|
||||
use crate::cs::transpose_cs;
|
||||
use crate::utils::first_and_last_offsets_are_ok;
|
||||
use crate::SparseFormatError;
|
||||
use std::error::Error;
|
||||
use std::fmt;
|
||||
@ -125,7 +126,6 @@ impl SparsityPattern {
|
||||
major_offsets: Vec<usize>,
|
||||
minor_indices: Vec<usize>,
|
||||
) -> Result<Self, SparsityPatternFormatError> {
|
||||
use nalgebra::base::helper;
|
||||
use SparsityPatternFormatError::*;
|
||||
|
||||
if major_offsets.len() != major_dim + 1 {
|
||||
@ -133,7 +133,7 @@ impl SparsityPattern {
|
||||
}
|
||||
|
||||
// Check that the first and last offsets conform to the specification
|
||||
if !helper::first_and_last_offsets_are_ok(&major_offsets, &minor_indices) {
|
||||
if !first_and_last_offsets_are_ok(&major_offsets, &minor_indices) {
|
||||
return Err(InvalidOffsetFirstLast);
|
||||
}
|
||||
|
||||
@ -261,6 +261,8 @@ impl SparsityPattern {
|
||||
pub enum SparsityPatternFormatError {
|
||||
/// Indicates an invalid number of offsets.
|
||||
///
|
||||
/// Indicates that column indices and values have different lengths.
|
||||
DifferentValuesIndicesLengths,
|
||||
/// The number of offsets must be equal to (major_dim + 1).
|
||||
InvalidOffsetArrayLength,
|
||||
/// Indicates that the first or last entry in the offset array did not conform to
|
||||
@ -289,7 +291,8 @@ impl From<SparsityPatternFormatError> for SparseFormatError {
|
||||
use SparsityPatternFormatError::DuplicateEntry as PatternDuplicateEntry;
|
||||
use SparsityPatternFormatError::*;
|
||||
match err {
|
||||
InvalidOffsetArrayLength
|
||||
DifferentValuesIndicesLengths
|
||||
| InvalidOffsetArrayLength
|
||||
| InvalidOffsetFirstLast
|
||||
| NonmonotonicOffsets
|
||||
| NonmonotonicMinorIndices => {
|
||||
@ -310,6 +313,9 @@ impl From<SparsityPatternFormatError> for SparseFormatError {
|
||||
impl fmt::Display for SparsityPatternFormatError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
SparsityPatternFormatError::DifferentValuesIndicesLengths => {
|
||||
write!(f, "Lengths of values and column indices are not equal.")
|
||||
}
|
||||
SparsityPatternFormatError::InvalidOffsetArrayLength => {
|
||||
write!(f, "Length of offset array is not equal to (major_dim + 1).")
|
||||
}
|
||||
|
23
nalgebra-sparse/src/utils.rs
Normal file
23
nalgebra-sparse/src/utils.rs
Normal file
@ -0,0 +1,23 @@
|
||||
//! Helper functions for sparse matrix computations
|
||||
|
||||
/// Check that the first and last offsets conform to the specification of a CSR matrix
|
||||
#[inline]
|
||||
#[must_use]
|
||||
pub fn first_and_last_offsets_are_ok(
|
||||
major_offsets: &Vec<usize>,
|
||||
minor_indices: &Vec<usize>,
|
||||
) -> bool {
|
||||
let first_offset_ok = *major_offsets.first().unwrap() == 0;
|
||||
let last_offset_ok = *major_offsets.last().unwrap() == minor_indices.len();
|
||||
return first_offset_ok && last_offset_ok;
|
||||
}
|
||||
|
||||
/// permutes entries of in_slice according to permutation slice and puts them to out_slice
|
||||
#[inline]
|
||||
pub fn apply_permutation<T: Clone>(out_slice: &mut [T], in_slice: &[T], permutation: &[usize]) {
|
||||
assert_eq!(out_slice.len(), in_slice.len());
|
||||
assert_eq!(out_slice.len(), permutation.len());
|
||||
for (out_element, old_pos) in out_slice.iter_mut().zip(permutation) {
|
||||
*out_element = in_slice[*old_pos].clone();
|
||||
}
|
||||
}
|
@ -178,7 +178,7 @@ fn csr_matrix_valid_data_unsorted_column_indices() {
|
||||
4,
|
||||
vec![0, 1, 2, 5],
|
||||
vec![1, 3, 2, 3, 0],
|
||||
vec![5, 4, 1, 1, 4],
|
||||
vec![5, 4, 1, 4, 1],
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
|
@ -29,15 +29,3 @@ where
|
||||
use std::iter;
|
||||
iter::repeat(()).map(|_| g.gen()).find(f).unwrap()
|
||||
}
|
||||
|
||||
/// Check that the first and last offsets conform to the specification of a CSR matrix
|
||||
#[inline]
|
||||
#[must_use]
|
||||
pub fn first_and_last_offsets_are_ok(
|
||||
major_offsets: &Vec<usize>,
|
||||
minor_indices: &Vec<usize>,
|
||||
) -> bool {
|
||||
let first_offset_ok = *major_offsets.first().unwrap() == 0;
|
||||
let last_offset_ok = *major_offsets.last().unwrap() == minor_indices.len();
|
||||
return first_offset_ok && last_offset_ok;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user