Apply permutation
This commit is contained in:
parent
a2a55cddca
commit
469765a4e5
|
@ -14,6 +14,7 @@ use crate::coo::CooMatrix;
|
||||||
use crate::cs;
|
use crate::cs;
|
||||||
use crate::csc::CscMatrix;
|
use crate::csc::CscMatrix;
|
||||||
use crate::csr::CsrMatrix;
|
use crate::csr::CsrMatrix;
|
||||||
|
use crate::utils::apply_permutation;
|
||||||
|
|
||||||
/// Converts a dense matrix to [`CooMatrix`].
|
/// Converts a dense matrix to [`CooMatrix`].
|
||||||
pub fn convert_dense_coo<T, R, C, S>(dense: &Matrix<T, R, C, S>) -> CooMatrix<T>
|
pub fn convert_dense_coo<T, R, C, S>(dense: &Matrix<T, R, C, S>) -> CooMatrix<T>
|
||||||
|
@ -390,15 +391,6 @@ fn sort_lane<T: Clone>(
|
||||||
apply_permutation(values_result, values, permutation);
|
apply_permutation(values_result, values, permutation);
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Move this into `utils` or something?
|
|
||||||
fn apply_permutation<T: Clone>(out_slice: &mut [T], in_slice: &[T], permutation: &[usize]) {
|
|
||||||
assert_eq!(out_slice.len(), in_slice.len());
|
|
||||||
assert_eq!(out_slice.len(), permutation.len());
|
|
||||||
for (out_element, old_pos) in out_slice.iter_mut().zip(permutation) {
|
|
||||||
*out_element = in_slice[*old_pos].clone();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Given *sorted* indices and corresponding scalar values, combines duplicates with the given
|
/// Given *sorted* indices and corresponding scalar values, combines duplicates with the given
|
||||||
/// associative combiner and calls the provided produce methods with combined indices and values.
|
/// associative combiner and calls the provided produce methods with combined indices and values.
|
||||||
fn combine_duplicates<T: Clone>(
|
fn combine_duplicates<T: Clone>(
|
||||||
|
|
|
@ -535,6 +535,10 @@ fn pattern_format_error_to_csc_error(err: SparsityPatternFormatError) -> SparseF
|
||||||
use SparsityPatternFormatError::*;
|
use SparsityPatternFormatError::*;
|
||||||
|
|
||||||
match err {
|
match err {
|
||||||
|
DifferentValuesIndicesLengths => E::from_kind_and_msg(
|
||||||
|
K::InvalidStructure,
|
||||||
|
"Lengths of values and column indices are not equal.",
|
||||||
|
),
|
||||||
InvalidOffsetArrayLength => E::from_kind_and_msg(
|
InvalidOffsetArrayLength => E::from_kind_and_msg(
|
||||||
K::InvalidStructure,
|
K::InvalidStructure,
|
||||||
"Length of col offset array is not equal to ncols + 1.",
|
"Length of col offset array is not equal to ncols + 1.",
|
||||||
|
|
|
@ -5,11 +5,13 @@
|
||||||
use crate::cs::{CsLane, CsLaneIter, CsLaneIterMut, CsLaneMut, CsMatrix};
|
use crate::cs::{CsLane, CsLaneIter, CsLaneIterMut, CsLaneMut, CsMatrix};
|
||||||
use crate::csc::CscMatrix;
|
use crate::csc::CscMatrix;
|
||||||
use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter};
|
use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter};
|
||||||
|
use crate::utils::{apply_permutation, first_and_last_offsets_are_ok};
|
||||||
use crate::{SparseEntry, SparseEntryMut, SparseFormatError, SparseFormatErrorKind};
|
use crate::{SparseEntry, SparseEntryMut, SparseFormatError, SparseFormatErrorKind};
|
||||||
|
|
||||||
use nalgebra::Scalar;
|
use nalgebra::Scalar;
|
||||||
use num_traits::One;
|
use num_traits::One;
|
||||||
|
|
||||||
|
use num_traits::Zero;
|
||||||
use std::slice::{Iter, IterMut};
|
use std::slice::{Iter, IterMut};
|
||||||
|
|
||||||
/// A CSR representation of a sparse matrix.
|
/// A CSR representation of a sparse matrix.
|
||||||
|
@ -184,20 +186,40 @@ impl<T> CsrMatrix<T> {
|
||||||
row_offsets: Vec<usize>,
|
row_offsets: Vec<usize>,
|
||||||
col_indices: Vec<usize>,
|
col_indices: Vec<usize>,
|
||||||
values: Vec<T>,
|
values: Vec<T>,
|
||||||
) -> Result<Self, SparseFormatError> {
|
) -> Result<Self, SparseFormatError>
|
||||||
use nalgebra::base::helper;
|
where
|
||||||
|
T: Scalar + Zero,
|
||||||
|
{
|
||||||
use SparsityPatternFormatError::*;
|
use SparsityPatternFormatError::*;
|
||||||
if helper::first_and_last_offsets_are_ok(&row_offsets, &col_indices) {
|
|
||||||
let mut sorted_col_indices = col_indices.clone();
|
let mut p: Vec<usize> = (0..col_indices.len()).collect();
|
||||||
for (index, &offset) in row_offsets[0..row_offsets.len() - 1].iter().enumerate() {
|
|
||||||
sorted_col_indices[offset..row_offsets[index + 1]].sort_unstable();
|
if col_indices.len() != values.len() {
|
||||||
|
return (Err(DifferentValuesIndicesLengths)).map_err(pattern_format_error_to_csr_error);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if first_and_last_offsets_are_ok(&row_offsets, &col_indices) {
|
||||||
|
for (index, &offset) in row_offsets[0..row_offsets.len() - 1].iter().enumerate() {
|
||||||
|
p[offset..row_offsets[index + 1]].sort_by(|a, b| {
|
||||||
|
let x = &col_indices[*a];
|
||||||
|
let y = &col_indices[*b];
|
||||||
|
x.partial_cmp(y).unwrap()
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// permute indices
|
||||||
|
let sorted_col_indices: Vec<usize> = p.iter().map(|i| col_indices[*i]).collect();
|
||||||
|
|
||||||
|
// permute values
|
||||||
|
let mut output: Vec<T> = vec![T::zero(); p.len()];
|
||||||
|
apply_permutation(&mut output[..p.len()], &values[..p.len()], &p[..p.len()]);
|
||||||
|
|
||||||
return Self::try_from_csr_data(
|
return Self::try_from_csr_data(
|
||||||
num_rows,
|
num_rows,
|
||||||
num_cols,
|
num_cols,
|
||||||
row_offsets,
|
row_offsets,
|
||||||
sorted_col_indices,
|
sorted_col_indices,
|
||||||
values,
|
output,
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
return (Err(InvalidOffsetFirstLast)).map_err(pattern_format_error_to_csr_error);
|
return (Err(InvalidOffsetFirstLast)).map_err(pattern_format_error_to_csr_error);
|
||||||
|
@ -568,6 +590,10 @@ fn pattern_format_error_to_csr_error(err: SparsityPatternFormatError) -> SparseF
|
||||||
use SparsityPatternFormatError::*;
|
use SparsityPatternFormatError::*;
|
||||||
|
|
||||||
match err {
|
match err {
|
||||||
|
DifferentValuesIndicesLengths => E::from_kind_and_msg(
|
||||||
|
K::InvalidStructure,
|
||||||
|
"Lengths of values and column indices are not equal.",
|
||||||
|
),
|
||||||
InvalidOffsetArrayLength => E::from_kind_and_msg(
|
InvalidOffsetArrayLength => E::from_kind_and_msg(
|
||||||
K::InvalidStructure,
|
K::InvalidStructure,
|
||||||
"Length of row offset array is not equal to nrows + 1.",
|
"Length of row offset array is not equal to nrows + 1.",
|
||||||
|
|
|
@ -149,6 +149,7 @@ pub mod csr;
|
||||||
pub mod factorization;
|
pub mod factorization;
|
||||||
pub mod ops;
|
pub mod ops;
|
||||||
pub mod pattern;
|
pub mod pattern;
|
||||||
|
pub mod utils;
|
||||||
|
|
||||||
pub(crate) mod cs;
|
pub(crate) mod cs;
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
//! Sparsity patterns for CSR and CSC matrices.
|
//! Sparsity patterns for CSR and CSC matrices.
|
||||||
use crate::cs::transpose_cs;
|
use crate::cs::transpose_cs;
|
||||||
|
use crate::utils::first_and_last_offsets_are_ok;
|
||||||
use crate::SparseFormatError;
|
use crate::SparseFormatError;
|
||||||
use std::error::Error;
|
use std::error::Error;
|
||||||
use std::fmt;
|
use std::fmt;
|
||||||
|
@ -125,7 +126,6 @@ impl SparsityPattern {
|
||||||
major_offsets: Vec<usize>,
|
major_offsets: Vec<usize>,
|
||||||
minor_indices: Vec<usize>,
|
minor_indices: Vec<usize>,
|
||||||
) -> Result<Self, SparsityPatternFormatError> {
|
) -> Result<Self, SparsityPatternFormatError> {
|
||||||
use nalgebra::base::helper;
|
|
||||||
use SparsityPatternFormatError::*;
|
use SparsityPatternFormatError::*;
|
||||||
|
|
||||||
if major_offsets.len() != major_dim + 1 {
|
if major_offsets.len() != major_dim + 1 {
|
||||||
|
@ -133,7 +133,7 @@ impl SparsityPattern {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check that the first and last offsets conform to the specification
|
// Check that the first and last offsets conform to the specification
|
||||||
if !helper::first_and_last_offsets_are_ok(&major_offsets, &minor_indices) {
|
if !first_and_last_offsets_are_ok(&major_offsets, &minor_indices) {
|
||||||
return Err(InvalidOffsetFirstLast);
|
return Err(InvalidOffsetFirstLast);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -261,6 +261,8 @@ impl SparsityPattern {
|
||||||
pub enum SparsityPatternFormatError {
|
pub enum SparsityPatternFormatError {
|
||||||
/// Indicates an invalid number of offsets.
|
/// Indicates an invalid number of offsets.
|
||||||
///
|
///
|
||||||
|
/// Indicates that column indices and values have different lengths.
|
||||||
|
DifferentValuesIndicesLengths,
|
||||||
/// The number of offsets must be equal to (major_dim + 1).
|
/// The number of offsets must be equal to (major_dim + 1).
|
||||||
InvalidOffsetArrayLength,
|
InvalidOffsetArrayLength,
|
||||||
/// Indicates that the first or last entry in the offset array did not conform to
|
/// Indicates that the first or last entry in the offset array did not conform to
|
||||||
|
@ -289,7 +291,8 @@ impl From<SparsityPatternFormatError> for SparseFormatError {
|
||||||
use SparsityPatternFormatError::DuplicateEntry as PatternDuplicateEntry;
|
use SparsityPatternFormatError::DuplicateEntry as PatternDuplicateEntry;
|
||||||
use SparsityPatternFormatError::*;
|
use SparsityPatternFormatError::*;
|
||||||
match err {
|
match err {
|
||||||
InvalidOffsetArrayLength
|
DifferentValuesIndicesLengths
|
||||||
|
| InvalidOffsetArrayLength
|
||||||
| InvalidOffsetFirstLast
|
| InvalidOffsetFirstLast
|
||||||
| NonmonotonicOffsets
|
| NonmonotonicOffsets
|
||||||
| NonmonotonicMinorIndices => {
|
| NonmonotonicMinorIndices => {
|
||||||
|
@ -310,6 +313,9 @@ impl From<SparsityPatternFormatError> for SparseFormatError {
|
||||||
impl fmt::Display for SparsityPatternFormatError {
|
impl fmt::Display for SparsityPatternFormatError {
|
||||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
match self {
|
match self {
|
||||||
|
SparsityPatternFormatError::DifferentValuesIndicesLengths => {
|
||||||
|
write!(f, "Lengths of values and column indices are not equal.")
|
||||||
|
}
|
||||||
SparsityPatternFormatError::InvalidOffsetArrayLength => {
|
SparsityPatternFormatError::InvalidOffsetArrayLength => {
|
||||||
write!(f, "Length of offset array is not equal to (major_dim + 1).")
|
write!(f, "Length of offset array is not equal to (major_dim + 1).")
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,23 @@
|
||||||
|
//! Helper functions for sparse matrix computations
|
||||||
|
|
||||||
|
/// Check that the first and last offsets conform to the specification of a CSR matrix
|
||||||
|
#[inline]
|
||||||
|
#[must_use]
|
||||||
|
pub fn first_and_last_offsets_are_ok(
|
||||||
|
major_offsets: &Vec<usize>,
|
||||||
|
minor_indices: &Vec<usize>,
|
||||||
|
) -> bool {
|
||||||
|
let first_offset_ok = *major_offsets.first().unwrap() == 0;
|
||||||
|
let last_offset_ok = *major_offsets.last().unwrap() == minor_indices.len();
|
||||||
|
return first_offset_ok && last_offset_ok;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// permutes entries of in_slice according to permutation slice and puts them to out_slice
|
||||||
|
#[inline]
|
||||||
|
pub fn apply_permutation<T: Clone>(out_slice: &mut [T], in_slice: &[T], permutation: &[usize]) {
|
||||||
|
assert_eq!(out_slice.len(), in_slice.len());
|
||||||
|
assert_eq!(out_slice.len(), permutation.len());
|
||||||
|
for (out_element, old_pos) in out_slice.iter_mut().zip(permutation) {
|
||||||
|
*out_element = in_slice[*old_pos].clone();
|
||||||
|
}
|
||||||
|
}
|
|
@ -178,7 +178,7 @@ fn csr_matrix_valid_data_unsorted_column_indices() {
|
||||||
4,
|
4,
|
||||||
vec![0, 1, 2, 5],
|
vec![0, 1, 2, 5],
|
||||||
vec![1, 3, 2, 3, 0],
|
vec![1, 3, 2, 3, 0],
|
||||||
vec![5, 4, 1, 1, 4],
|
vec![5, 4, 1, 4, 1],
|
||||||
)
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
|
|
|
@ -29,15 +29,3 @@ where
|
||||||
use std::iter;
|
use std::iter;
|
||||||
iter::repeat(()).map(|_| g.gen()).find(f).unwrap()
|
iter::repeat(()).map(|_| g.gen()).find(f).unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Check that the first and last offsets conform to the specification of a CSR matrix
|
|
||||||
#[inline]
|
|
||||||
#[must_use]
|
|
||||||
pub fn first_and_last_offsets_are_ok(
|
|
||||||
major_offsets: &Vec<usize>,
|
|
||||||
minor_indices: &Vec<usize>,
|
|
||||||
) -> bool {
|
|
||||||
let first_offset_ok = *major_offsets.first().unwrap() == 0;
|
|
||||||
let last_offset_ok = *major_offsets.last().unwrap() == minor_indices.len();
|
|
||||||
return first_offset_ok && last_offset_ok;
|
|
||||||
}
|
|
||||||
|
|
Loading…
Reference in New Issue