2020-12-22 18:01:50 +08:00
|
|
|
use std::mem::replace;
|
2021-01-19 21:15:19 +08:00
|
|
|
use std::ops::Range;
|
|
|
|
|
2021-01-11 22:03:58 +08:00
|
|
|
use num_traits::One;
|
2021-01-19 21:15:19 +08:00
|
|
|
|
2021-01-11 22:03:58 +08:00
|
|
|
use nalgebra::Scalar;
|
2020-12-22 17:19:17 +08:00
|
|
|
|
2021-01-19 21:15:19 +08:00
|
|
|
use crate::pattern::SparsityPattern;
|
2022-03-03 17:14:16 +08:00
|
|
|
use crate::utils::{apply_permutation, compute_sort_permutation};
|
|
|
|
use crate::{SparseEntry, SparseEntryMut, SparseFormatError, SparseFormatErrorKind};
|
2021-01-19 21:15:19 +08:00
|
|
|
|
2020-12-22 17:19:17 +08:00
|
|
|
/// An abstract compressed matrix.
|
|
|
|
///
|
|
|
|
/// For the time being, this is only used internally to share implementation between
|
|
|
|
/// CSR and CSC matrices.
|
|
|
|
///
|
|
|
|
/// A CSR matrix is obtained by associating rows with the major dimension, while a CSC matrix
|
|
|
|
/// is obtained by associating columns with the major dimension.
|
|
|
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
|
|
|
pub struct CsMatrix<T> {
|
2021-01-19 23:53:39 +08:00
|
|
|
sparsity_pattern: SparsityPattern,
|
2021-01-26 00:26:27 +08:00
|
|
|
values: Vec<T>,
|
2020-12-22 17:19:17 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
impl<T> CsMatrix<T> {
|
|
|
|
/// Create a zero matrix with no explicitly stored entries.
|
|
|
|
#[inline]
|
|
|
|
pub fn new(major_dim: usize, minor_dim: usize) -> Self {
|
|
|
|
Self {
|
2021-01-25 23:04:29 +08:00
|
|
|
sparsity_pattern: SparsityPattern::zeros(major_dim, minor_dim),
|
2020-12-22 17:19:17 +08:00
|
|
|
values: vec![],
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
#[inline]
|
2021-06-07 22:34:03 +08:00
|
|
|
#[must_use]
|
2021-01-19 23:53:39 +08:00
|
|
|
pub fn pattern(&self) -> &SparsityPattern {
|
2020-12-22 17:19:17 +08:00
|
|
|
&self.sparsity_pattern
|
|
|
|
}
|
|
|
|
|
|
|
|
#[inline]
|
2021-06-07 22:34:03 +08:00
|
|
|
#[must_use]
|
2020-12-22 17:19:17 +08:00
|
|
|
pub fn values(&self) -> &[T] {
|
|
|
|
&self.values
|
|
|
|
}
|
|
|
|
|
|
|
|
#[inline]
|
|
|
|
pub fn values_mut(&mut self) -> &mut [T] {
|
|
|
|
&mut self.values
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Returns the raw data represented as a tuple `(major_offsets, minor_indices, values)`.
|
|
|
|
#[inline]
|
2021-06-07 22:34:03 +08:00
|
|
|
#[must_use]
|
2020-12-22 17:19:17 +08:00
|
|
|
pub fn cs_data(&self) -> (&[usize], &[usize], &[T]) {
|
2021-01-19 23:53:39 +08:00
|
|
|
let pattern = self.pattern();
|
2021-01-26 00:26:27 +08:00
|
|
|
(
|
|
|
|
pattern.major_offsets(),
|
|
|
|
pattern.minor_indices(),
|
|
|
|
&self.values,
|
|
|
|
)
|
2020-12-22 17:19:17 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
/// Returns the raw data represented as a tuple `(major_offsets, minor_indices, values)`.
|
|
|
|
#[inline]
|
|
|
|
pub fn cs_data_mut(&mut self) -> (&[usize], &[usize], &mut [T]) {
|
2021-01-19 23:53:39 +08:00
|
|
|
let pattern = &mut self.sparsity_pattern;
|
2021-01-26 00:26:27 +08:00
|
|
|
(
|
|
|
|
pattern.major_offsets(),
|
|
|
|
pattern.minor_indices(),
|
|
|
|
&mut self.values,
|
|
|
|
)
|
2020-12-22 17:19:17 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
#[inline]
|
2021-01-19 23:53:39 +08:00
|
|
|
pub fn pattern_and_values_mut(&mut self) -> (&SparsityPattern, &mut [T]) {
|
2020-12-22 17:19:17 +08:00
|
|
|
(&self.sparsity_pattern, &mut self.values)
|
|
|
|
}
|
|
|
|
|
|
|
|
#[inline]
|
2021-01-26 00:26:27 +08:00
|
|
|
pub fn from_pattern_and_values(pattern: SparsityPattern, values: Vec<T>) -> Self {
|
|
|
|
assert_eq!(
|
|
|
|
pattern.nnz(),
|
|
|
|
values.len(),
|
|
|
|
"Internal error: consumers should verify shape compatibility."
|
|
|
|
);
|
2020-12-22 17:19:17 +08:00
|
|
|
Self {
|
|
|
|
sparsity_pattern: pattern,
|
|
|
|
values,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Internal method for simplifying access to a lane's data
|
|
|
|
#[inline]
|
2021-06-07 22:34:03 +08:00
|
|
|
#[must_use]
|
2020-12-22 17:19:17 +08:00
|
|
|
pub fn get_index_range(&self, row_index: usize) -> Option<Range<usize>> {
|
|
|
|
let row_begin = *self.sparsity_pattern.major_offsets().get(row_index)?;
|
|
|
|
let row_end = *self.sparsity_pattern.major_offsets().get(row_index + 1)?;
|
2021-01-26 00:26:27 +08:00
|
|
|
Some(row_begin..row_end)
|
2020-12-22 17:19:17 +08:00
|
|
|
}
|
|
|
|
|
2021-01-19 23:53:39 +08:00
|
|
|
pub fn take_pattern_and_values(self) -> (SparsityPattern, Vec<T>) {
|
2020-12-22 17:19:17 +08:00
|
|
|
(self.sparsity_pattern, self.values)
|
|
|
|
}
|
|
|
|
|
|
|
|
#[inline]
|
|
|
|
pub fn disassemble(self) -> (Vec<usize>, Vec<usize>, Vec<T>) {
|
2021-01-19 23:53:39 +08:00
|
|
|
let (offsets, indices) = self.sparsity_pattern.disassemble();
|
|
|
|
(offsets, indices, self.values)
|
2020-12-22 17:19:17 +08:00
|
|
|
}
|
|
|
|
|
2021-01-19 23:56:30 +08:00
|
|
|
#[inline]
|
|
|
|
pub fn into_pattern_and_values(self) -> (SparsityPattern, Vec<T>) {
|
|
|
|
(self.sparsity_pattern, self.values)
|
|
|
|
}
|
|
|
|
|
2020-12-22 17:19:17 +08:00
|
|
|
/// Returns an entry for the given major/minor indices, or `None` if the indices are out
|
|
|
|
/// of bounds.
|
2021-06-07 22:34:03 +08:00
|
|
|
#[must_use]
|
2021-07-28 07:18:29 +08:00
|
|
|
pub fn get_entry(&self, major_index: usize, minor_index: usize) -> Option<SparseEntry<'_, T>> {
|
2020-12-22 17:19:17 +08:00
|
|
|
let row_range = self.get_index_range(major_index)?;
|
|
|
|
let (_, minor_indices, values) = self.cs_data();
|
|
|
|
let minor_indices = &minor_indices[row_range.clone()];
|
|
|
|
let values = &values[row_range];
|
2021-01-26 00:26:27 +08:00
|
|
|
get_entry_from_slices(
|
|
|
|
self.pattern().minor_dim(),
|
|
|
|
minor_indices,
|
|
|
|
values,
|
|
|
|
minor_index,
|
|
|
|
)
|
2020-12-22 17:19:17 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
/// Returns a mutable entry for the given major/minor indices, or `None` if the indices are out
|
|
|
|
/// of bounds.
|
2021-01-26 00:26:27 +08:00
|
|
|
pub fn get_entry_mut(
|
|
|
|
&mut self,
|
|
|
|
major_index: usize,
|
|
|
|
minor_index: usize,
|
2021-07-28 07:18:29 +08:00
|
|
|
) -> Option<SparseEntryMut<'_, T>> {
|
2020-12-22 17:19:17 +08:00
|
|
|
let row_range = self.get_index_range(major_index)?;
|
|
|
|
let minor_dim = self.pattern().minor_dim();
|
|
|
|
let (_, minor_indices, values) = self.cs_data_mut();
|
|
|
|
let minor_indices = &minor_indices[row_range.clone()];
|
|
|
|
let values = &mut values[row_range];
|
|
|
|
get_mut_entry_from_slices(minor_dim, minor_indices, values, minor_index)
|
|
|
|
}
|
|
|
|
|
2021-06-07 22:34:03 +08:00
|
|
|
#[must_use]
|
2021-07-28 07:18:29 +08:00
|
|
|
pub fn get_lane(&self, index: usize) -> Option<CsLane<'_, T>> {
|
2020-12-22 17:19:17 +08:00
|
|
|
let range = self.get_index_range(index)?;
|
|
|
|
let (_, minor_indices, values) = self.cs_data();
|
|
|
|
Some(CsLane {
|
|
|
|
minor_indices: &minor_indices[range.clone()],
|
|
|
|
values: &values[range],
|
2021-01-26 00:26:27 +08:00
|
|
|
minor_dim: self.pattern().minor_dim(),
|
2020-12-22 17:19:17 +08:00
|
|
|
})
|
|
|
|
}
|
|
|
|
|
|
|
|
#[inline]
|
2021-06-07 23:24:43 +08:00
|
|
|
#[must_use]
|
2021-07-28 07:18:29 +08:00
|
|
|
pub fn get_lane_mut(&mut self, index: usize) -> Option<CsLaneMut<'_, T>> {
|
2020-12-22 17:19:17 +08:00
|
|
|
let range = self.get_index_range(index)?;
|
|
|
|
let minor_dim = self.pattern().minor_dim();
|
|
|
|
let (_, minor_indices, values) = self.cs_data_mut();
|
|
|
|
Some(CsLaneMut {
|
|
|
|
minor_dim,
|
|
|
|
minor_indices: &minor_indices[range.clone()],
|
2021-01-26 00:26:27 +08:00
|
|
|
values: &mut values[range],
|
2020-12-22 17:19:17 +08:00
|
|
|
})
|
|
|
|
}
|
2020-12-30 23:09:46 +08:00
|
|
|
|
|
|
|
#[inline]
|
2021-07-28 07:18:29 +08:00
|
|
|
pub fn lane_iter(&self) -> CsLaneIter<'_, T> {
|
2021-01-19 23:53:39 +08:00
|
|
|
CsLaneIter::new(self.pattern(), self.values())
|
2020-12-30 23:09:46 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
#[inline]
|
2021-07-28 07:18:29 +08:00
|
|
|
pub fn lane_iter_mut(&mut self) -> CsLaneIterMut<'_, T> {
|
2021-01-19 23:53:39 +08:00
|
|
|
CsLaneIterMut::new(&self.sparsity_pattern, &mut self.values)
|
2020-12-30 23:09:46 +08:00
|
|
|
}
|
2021-01-15 00:12:08 +08:00
|
|
|
|
|
|
|
#[inline]
|
2021-06-07 22:34:03 +08:00
|
|
|
#[must_use]
|
2021-01-15 00:12:08 +08:00
|
|
|
pub fn filter<P>(&self, predicate: P) -> Self
|
|
|
|
where
|
|
|
|
T: Clone,
|
2021-01-26 00:26:27 +08:00
|
|
|
P: Fn(usize, usize, &T) -> bool,
|
2021-01-15 00:12:08 +08:00
|
|
|
{
|
|
|
|
let (major_dim, minor_dim) = (self.pattern().major_dim(), self.pattern().minor_dim());
|
|
|
|
let mut new_offsets = Vec::with_capacity(self.pattern().major_dim() + 1);
|
|
|
|
let mut new_indices = Vec::new();
|
|
|
|
let mut new_values = Vec::new();
|
|
|
|
|
|
|
|
new_offsets.push(0);
|
|
|
|
for (i, lane) in self.lane_iter().enumerate() {
|
|
|
|
for (&j, value) in lane.minor_indices().iter().zip(lane.values) {
|
|
|
|
if predicate(i, j, value) {
|
|
|
|
new_indices.push(j);
|
|
|
|
new_values.push(value.clone());
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
new_offsets.push(new_indices.len());
|
|
|
|
}
|
|
|
|
|
|
|
|
// TODO: Avoid checks here
|
|
|
|
let new_pattern = SparsityPattern::try_from_offsets_and_indices(
|
|
|
|
major_dim,
|
|
|
|
minor_dim,
|
|
|
|
new_offsets,
|
2021-01-26 00:26:27 +08:00
|
|
|
new_indices,
|
|
|
|
)
|
|
|
|
.expect("Internal error: Sparsity pattern must always be valid.");
|
2021-01-15 00:12:08 +08:00
|
|
|
|
2021-01-19 23:53:39 +08:00
|
|
|
Self::from_pattern_and_values(new_pattern, new_values)
|
2021-01-15 00:12:08 +08:00
|
|
|
}
|
2021-01-25 23:04:29 +08:00
|
|
|
|
|
|
|
/// Returns the diagonal of the matrix as a sparse matrix.
|
2021-06-07 22:34:03 +08:00
|
|
|
#[must_use]
|
2021-01-25 23:04:29 +08:00
|
|
|
pub fn diagonal_as_matrix(&self) -> Self
|
2021-01-26 00:26:27 +08:00
|
|
|
where
|
|
|
|
T: Clone,
|
2021-01-25 23:04:29 +08:00
|
|
|
{
|
|
|
|
// TODO: This might be faster with a binary search for each diagonal entry
|
|
|
|
self.filter(|i, j, _| i == j)
|
|
|
|
}
|
2020-12-22 17:19:17 +08:00
|
|
|
}
|
|
|
|
|
2021-01-11 22:03:58 +08:00
|
|
|
impl<T: Scalar + One> CsMatrix<T> {
|
|
|
|
#[inline]
|
|
|
|
pub fn identity(n: usize) -> Self {
|
2021-01-26 00:26:27 +08:00
|
|
|
let offsets: Vec<_> = (0..=n).collect();
|
|
|
|
let indices: Vec<_> = (0..n).collect();
|
2021-01-11 22:03:58 +08:00
|
|
|
let values = vec![T::one(); n];
|
|
|
|
|
|
|
|
// TODO: We should skip checks here
|
2021-01-26 00:26:27 +08:00
|
|
|
let pattern =
|
|
|
|
SparsityPattern::try_from_offsets_and_indices(n, n, offsets, indices).unwrap();
|
2021-01-19 23:53:39 +08:00
|
|
|
Self::from_pattern_and_values(pattern, values)
|
2021-01-11 22:03:58 +08:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2020-12-30 23:09:46 +08:00
|
|
|
fn get_entry_from_slices<'a, T>(
|
2020-12-22 17:19:17 +08:00
|
|
|
minor_dim: usize,
|
|
|
|
minor_indices: &'a [usize],
|
|
|
|
values: &'a [T],
|
2021-01-26 00:26:27 +08:00
|
|
|
global_minor_index: usize,
|
|
|
|
) -> Option<SparseEntry<'a, T>> {
|
2020-12-22 17:19:17 +08:00
|
|
|
let local_index = minor_indices.binary_search(&global_minor_index);
|
|
|
|
if let Ok(local_index) = local_index {
|
|
|
|
Some(SparseEntry::NonZero(&values[local_index]))
|
|
|
|
} else if global_minor_index < minor_dim {
|
|
|
|
Some(SparseEntry::Zero)
|
|
|
|
} else {
|
|
|
|
None
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2020-12-30 23:09:46 +08:00
|
|
|
fn get_mut_entry_from_slices<'a, T>(
|
2020-12-22 17:19:17 +08:00
|
|
|
minor_dim: usize,
|
|
|
|
minor_indices: &'a [usize],
|
|
|
|
values: &'a mut [T],
|
2021-01-26 00:26:27 +08:00
|
|
|
global_minor_indices: usize,
|
|
|
|
) -> Option<SparseEntryMut<'a, T>> {
|
2020-12-22 17:19:17 +08:00
|
|
|
let local_index = minor_indices.binary_search(&global_minor_indices);
|
|
|
|
if let Ok(local_index) = local_index {
|
|
|
|
Some(SparseEntryMut::NonZero(&mut values[local_index]))
|
|
|
|
} else if global_minor_indices < minor_dim {
|
|
|
|
Some(SparseEntryMut::Zero)
|
|
|
|
} else {
|
|
|
|
None
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
/// An immutable view of a single lane of a compressed matrix.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct CsLane<'a, T> {
    /// Size of the minor dimension of the matrix this lane belongs to.
    minor_dim: usize,
    /// Minor indices of the explicitly stored entries in this lane.
    minor_indices: &'a [usize],
    /// Values of the explicitly stored entries, aligned with `minor_indices`.
    values: &'a [T],
}
|
|
|
|
|
|
|
|
/// A view of a single lane of a compressed matrix whose values (but not indices) are mutable.
#[derive(Debug, Eq, PartialEq)]
pub struct CsLaneMut<'a, T> {
    /// Size of the minor dimension of the matrix this lane belongs to.
    minor_dim: usize,
    /// Minor indices of the explicitly stored entries in this lane.
    minor_indices: &'a [usize],
    /// Values of the explicitly stored entries, aligned with `minor_indices`.
    values: &'a mut [T],
}
|
|
|
|
|
|
|
|
pub struct CsLaneIter<'a, T> {
|
|
|
|
// The index of the lane that will be returned on the next iteration
|
|
|
|
current_lane_idx: usize,
|
|
|
|
pattern: &'a SparsityPattern,
|
|
|
|
remaining_values: &'a [T],
|
|
|
|
}
|
|
|
|
|
|
|
|
impl<'a, T> CsLaneIter<'a, T> {
|
|
|
|
pub fn new(pattern: &'a SparsityPattern, values: &'a [T]) -> Self {
|
|
|
|
Self {
|
|
|
|
current_lane_idx: 0,
|
|
|
|
pattern,
|
2021-01-26 00:26:27 +08:00
|
|
|
remaining_values: values,
|
2020-12-22 17:19:17 +08:00
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
impl<'a, T> Iterator for CsLaneIter<'a, T>
|
2021-01-26 00:26:27 +08:00
|
|
|
where
|
|
|
|
T: 'a,
|
2020-12-22 17:19:17 +08:00
|
|
|
{
|
|
|
|
type Item = CsLane<'a, T>;
|
|
|
|
|
|
|
|
fn next(&mut self) -> Option<Self::Item> {
|
|
|
|
let lane = self.pattern.get_lane(self.current_lane_idx);
|
|
|
|
let minor_dim = self.pattern.minor_dim();
|
|
|
|
|
|
|
|
if let Some(minor_indices) = lane {
|
|
|
|
let count = minor_indices.len();
|
|
|
|
let values_in_lane = &self.remaining_values[..count];
|
2021-01-26 00:26:27 +08:00
|
|
|
self.remaining_values = &self.remaining_values[count..];
|
2020-12-22 17:19:17 +08:00
|
|
|
self.current_lane_idx += 1;
|
|
|
|
|
|
|
|
Some(CsLane {
|
|
|
|
minor_dim,
|
|
|
|
minor_indices,
|
2021-01-26 00:26:27 +08:00
|
|
|
values: values_in_lane,
|
2020-12-22 17:19:17 +08:00
|
|
|
})
|
|
|
|
} else {
|
|
|
|
None
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
pub struct CsLaneIterMut<'a, T> {
|
|
|
|
// The index of the lane that will be returned on the next iteration
|
|
|
|
current_lane_idx: usize,
|
|
|
|
pattern: &'a SparsityPattern,
|
2020-12-22 18:01:50 +08:00
|
|
|
remaining_values: &'a mut [T],
|
2020-12-22 17:19:17 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
impl<'a, T> CsLaneIterMut<'a, T> {
|
|
|
|
pub fn new(pattern: &'a SparsityPattern, values: &'a mut [T]) -> Self {
|
|
|
|
Self {
|
|
|
|
current_lane_idx: 0,
|
|
|
|
pattern,
|
2021-01-26 00:26:27 +08:00
|
|
|
remaining_values: values,
|
2020-12-22 17:19:17 +08:00
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
impl<'a, T> Iterator for CsLaneIterMut<'a, T>
|
2021-01-26 00:26:27 +08:00
|
|
|
where
|
|
|
|
T: 'a,
|
2020-12-22 17:19:17 +08:00
|
|
|
{
|
|
|
|
type Item = CsLaneMut<'a, T>;
|
|
|
|
|
|
|
|
fn next(&mut self) -> Option<Self::Item> {
|
|
|
|
let lane = self.pattern.get_lane(self.current_lane_idx);
|
|
|
|
let minor_dim = self.pattern.minor_dim();
|
|
|
|
|
|
|
|
if let Some(minor_indices) = lane {
|
|
|
|
let count = minor_indices.len();
|
|
|
|
|
2020-12-22 18:01:50 +08:00
|
|
|
let remaining = replace(&mut self.remaining_values, &mut []);
|
|
|
|
let (values_in_lane, remaining) = remaining.split_at_mut(count);
|
|
|
|
self.remaining_values = remaining;
|
2020-12-22 17:19:17 +08:00
|
|
|
self.current_lane_idx += 1;
|
|
|
|
|
|
|
|
Some(CsLaneMut {
|
|
|
|
minor_dim,
|
|
|
|
minor_indices,
|
2021-01-26 00:26:27 +08:00
|
|
|
values: values_in_lane,
|
2020-12-22 17:19:17 +08:00
|
|
|
})
|
|
|
|
} else {
|
|
|
|
None
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2020-12-30 23:09:46 +08:00
|
|
|
/// Implement the methods common to both CsLane and CsLaneMut. See the documentation for the
|
|
|
|
/// methods delegated here by CsrMatrix and CscMatrix members for more information.
|
|
|
|
macro_rules! impl_cs_lane_common_methods {
|
|
|
|
($name:ty) => {
|
|
|
|
impl<'a, T> $name {
|
|
|
|
#[inline]
|
2021-06-07 22:34:03 +08:00
|
|
|
#[must_use]
|
2020-12-30 23:09:46 +08:00
|
|
|
pub fn minor_dim(&self) -> usize {
|
|
|
|
self.minor_dim
|
|
|
|
}
|
|
|
|
|
|
|
|
#[inline]
|
2021-06-07 22:34:03 +08:00
|
|
|
#[must_use]
|
2020-12-30 23:09:46 +08:00
|
|
|
pub fn nnz(&self) -> usize {
|
|
|
|
self.minor_indices.len()
|
|
|
|
}
|
|
|
|
|
|
|
|
#[inline]
|
2021-06-07 22:34:03 +08:00
|
|
|
#[must_use]
|
2020-12-30 23:09:46 +08:00
|
|
|
pub fn minor_indices(&self) -> &[usize] {
|
|
|
|
self.minor_indices
|
|
|
|
}
|
|
|
|
|
|
|
|
#[inline]
|
2021-06-07 22:34:03 +08:00
|
|
|
#[must_use]
|
2020-12-30 23:09:46 +08:00
|
|
|
pub fn values(&self) -> &[T] {
|
|
|
|
self.values
|
|
|
|
}
|
|
|
|
|
|
|
|
#[inline]
|
2021-06-07 22:34:03 +08:00
|
|
|
#[must_use]
|
2021-07-28 07:18:29 +08:00
|
|
|
pub fn get_entry(&self, global_col_index: usize) -> Option<SparseEntry<'_, T>> {
|
2020-12-30 23:09:46 +08:00
|
|
|
get_entry_from_slices(
|
|
|
|
self.minor_dim,
|
|
|
|
self.minor_indices,
|
|
|
|
self.values,
|
2021-01-26 00:26:27 +08:00
|
|
|
global_col_index,
|
|
|
|
)
|
2020-12-30 23:09:46 +08:00
|
|
|
}
|
|
|
|
}
|
2021-01-26 00:26:27 +08:00
|
|
|
};
|
2020-12-30 23:09:46 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
impl_cs_lane_common_methods!(CsLane<'a, T>);
|
|
|
|
impl_cs_lane_common_methods!(CsLaneMut<'a, T>);
|
|
|
|
|
|
|
|
impl<'a, T> CsLaneMut<'a, T> {
|
|
|
|
pub fn values_mut(&mut self) -> &mut [T] {
|
|
|
|
self.values
|
|
|
|
}
|
2020-12-22 17:19:17 +08:00
|
|
|
|
2020-12-30 23:09:46 +08:00
|
|
|
pub fn indices_and_values_mut(&mut self) -> (&[usize], &mut [T]) {
|
|
|
|
(self.minor_indices, self.values)
|
|
|
|
}
|
|
|
|
|
2021-06-07 23:24:43 +08:00
|
|
|
#[must_use]
|
2021-07-28 07:18:29 +08:00
|
|
|
pub fn get_entry_mut(&mut self, global_minor_index: usize) -> Option<SparseEntryMut<'_, T>> {
|
2021-01-26 00:26:27 +08:00
|
|
|
get_mut_entry_from_slices(
|
|
|
|
self.minor_dim,
|
|
|
|
self.minor_indices,
|
|
|
|
self.values,
|
|
|
|
global_minor_index,
|
|
|
|
)
|
2020-12-30 23:09:46 +08:00
|
|
|
}
|
2021-01-19 21:15:19 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
/// Helper struct for working with uninitialized data in vectors.
/// TODO: This doesn't belong here.
///
/// Elements are written (in any order) with `set`, and the result is obtained with
/// `assume_init`. If the helper is dropped before `assume_init`, the inner `Vec` still has
/// length 0, so already-written elements are leaked (never dropped) rather than causing UB.
struct UninitVec<T> {
    vec: Vec<T>,
    // We need to store len separately, because for zero-sized types,
    // Vec::with_capacity(len) does not give vec.capacity() == len
    len: usize,
}

impl<T> UninitVec<T> {
    /// Reserves storage for exactly `len` elements, none of which is initialized yet.
    pub fn from_len(len: usize) -> Self {
        Self {
            vec: Vec::with_capacity(len),
            len,
        }
    }

    /// Sets the element associated with the given index to the provided value.
    ///
    /// # Safety
    ///
    /// `index` must be less than the `len` passed to `from_len`; writing out of bounds is
    /// undefined behavior. Each index should be written exactly once: writing an index a
    /// second time overwrites the previous value without dropping it (a leak).
    pub unsafe fn set(&mut self, index: usize, value: T) {
        // Catch the one contract violation that is actual UB (an out-of-bounds write) in debug
        // builds.
        debug_assert!(
            index < self.len,
            "UninitVec::set: index {} out of bounds for length {}",
            index,
            self.len
        );
        // SAFETY: `from_len` reserved capacity for at least `self.len` elements and the caller
        // guarantees `index < self.len`, so the pointer stays inside the allocation. `write`
        // neither reads nor drops the (uninitialized) previous contents.
        self.vec.as_mut_ptr().add(index).write(value)
    }

    /// Marks the vector data as initialized by returning a full vector.
    ///
    /// # Safety
    ///
    /// It is undefined behavior to call this function unless *all* elements in `0..len` have
    /// been written via `set`.
    pub unsafe fn assume_init(mut self) -> Vec<T> {
        // SAFETY: capacity is at least `self.len` (see `from_len`), and the caller guarantees
        // that every element in `0..self.len` has been initialized.
        self.vec.set_len(self.len);
        self.vec
    }
}
|
|
|
|
|
|
|
|
/// Transposes the compressed format.
|
|
|
|
///
|
|
|
|
/// This means that major and minor roles are switched. This is used for converting between CSR
|
|
|
|
/// and CSC formats.
|
|
|
|
pub fn transpose_cs<T>(
|
2021-01-26 00:26:27 +08:00
|
|
|
major_dim: usize,
|
|
|
|
minor_dim: usize,
|
|
|
|
source_major_offsets: &[usize],
|
|
|
|
source_minor_indices: &[usize],
|
|
|
|
values: &[T],
|
|
|
|
) -> (Vec<usize>, Vec<usize>, Vec<T>)
|
2021-01-19 21:15:19 +08:00
|
|
|
where
|
2021-01-26 00:26:27 +08:00
|
|
|
T: Scalar,
|
2021-01-19 21:15:19 +08:00
|
|
|
{
|
|
|
|
assert_eq!(source_major_offsets.len(), major_dim + 1);
|
|
|
|
assert_eq!(source_minor_indices.len(), values.len());
|
|
|
|
let nnz = values.len();
|
|
|
|
|
|
|
|
// Count the number of occurences of each minor index
|
|
|
|
let mut minor_counts = vec![0; minor_dim];
|
|
|
|
for minor_idx in source_minor_indices {
|
|
|
|
minor_counts[*minor_idx] += 1;
|
|
|
|
}
|
|
|
|
convert_counts_to_offsets(&mut minor_counts);
|
|
|
|
let mut target_offsets = minor_counts;
|
|
|
|
target_offsets.push(nnz);
|
|
|
|
let mut target_indices = vec![usize::MAX; nnz];
|
|
|
|
|
|
|
|
// We have to use uninitialized storage, because we don't have any kind of "default" value
|
|
|
|
// available for `T`. Unfortunately this necessitates some small amount of unsafe code
|
|
|
|
let mut target_values = UninitVec::from_len(nnz);
|
|
|
|
|
|
|
|
// Keep track of how many entries we have placed in each target major lane
|
|
|
|
let mut current_target_major_counts = vec![0; minor_dim];
|
|
|
|
|
2021-01-26 00:26:27 +08:00
|
|
|
for source_major_idx in 0..major_dim {
|
2021-01-19 21:15:19 +08:00
|
|
|
let source_lane_begin = source_major_offsets[source_major_idx];
|
|
|
|
let source_lane_end = source_major_offsets[source_major_idx + 1];
|
2021-01-26 00:26:27 +08:00
|
|
|
let source_lane_indices = &source_minor_indices[source_lane_begin..source_lane_end];
|
|
|
|
let source_lane_values = &values[source_lane_begin..source_lane_end];
|
2021-01-19 21:15:19 +08:00
|
|
|
|
|
|
|
for (&source_minor_idx, val) in source_lane_indices.iter().zip(source_lane_values) {
|
|
|
|
// Compute the offset in the target data for this particular source entry
|
2021-01-26 00:26:27 +08:00
|
|
|
let target_lane_count = &mut current_target_major_counts[source_minor_idx];
|
2021-01-19 21:15:19 +08:00
|
|
|
let entry_offset = target_offsets[source_minor_idx] + *target_lane_count;
|
|
|
|
target_indices[entry_offset] = source_major_idx;
|
2021-01-26 00:26:27 +08:00
|
|
|
unsafe {
|
2021-08-04 23:34:25 +08:00
|
|
|
target_values.set(entry_offset, val.clone());
|
2021-01-26 00:26:27 +08:00
|
|
|
}
|
2021-01-19 21:15:19 +08:00
|
|
|
*target_lane_count += 1;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// At this point, we should have written to each element in target_values exactly once,
|
|
|
|
// so initialization should be sound
|
|
|
|
let target_values = unsafe { target_values.assume_init() };
|
|
|
|
(target_offsets, target_indices, target_values)
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Converts per-lane counts into starting offsets, in place (an exclusive prefix sum).
///
/// For example, counts `[2, 0, 3]` become `[0, 2, 2]`. The grand total is not stored; callers
/// that need a trailing end offset append it themselves.
pub fn convert_counts_to_offsets(counts: &mut [usize]) {
    let mut running_total = 0;
    for slot in counts.iter_mut() {
        // Store the offset accumulated so far and carry this lane's count forward.
        let lane_count = std::mem::replace(slot, running_total);
        running_total += lane_count;
    }
}
|
2022-03-03 17:14:16 +08:00
|
|
|
|
|
|
|
/// Validates cs data, optionally sorts minor indices and values
|
|
|
|
pub(crate) fn validate_and_optionally_sort_cs_data<T>(
|
|
|
|
major_dim: usize,
|
|
|
|
minor_dim: usize,
|
|
|
|
major_offsets: &[usize],
|
|
|
|
minor_indices: &mut [usize],
|
|
|
|
values: Option<&mut [T]>,
|
|
|
|
sort: bool,
|
|
|
|
) -> Result<(), SparseFormatError>
|
|
|
|
where
|
|
|
|
T: Scalar,
|
|
|
|
{
|
|
|
|
let mut values_option = values;
|
|
|
|
|
|
|
|
if let Some(values) = values_option.as_mut() {
|
|
|
|
if minor_indices.len() != values.len() {
|
|
|
|
return Err(SparseFormatError::from_kind_and_msg(
|
|
|
|
SparseFormatErrorKind::InvalidStructure,
|
|
|
|
"Number of values and minor indices must be the same.",
|
|
|
|
));
|
|
|
|
}
|
|
|
|
} else if sort {
|
|
|
|
unreachable!("Internal error: Sorting currently not supported if no values are present.");
|
|
|
|
}
|
|
|
|
if major_offsets.len() == 0 {
|
|
|
|
return Err(SparseFormatError::from_kind_and_msg(
|
|
|
|
SparseFormatErrorKind::InvalidStructure,
|
|
|
|
"Number of offsets should be greater than 0.",
|
|
|
|
));
|
|
|
|
}
|
|
|
|
if major_offsets.len() != major_dim + 1 {
|
|
|
|
return Err(SparseFormatError::from_kind_and_msg(
|
|
|
|
SparseFormatErrorKind::InvalidStructure,
|
|
|
|
"Length of offset array is not equal to (major_dim + 1).",
|
|
|
|
));
|
|
|
|
}
|
|
|
|
|
|
|
|
// Check that the first and last offsets conform to the specification
|
|
|
|
{
|
|
|
|
let first_offset_ok = *major_offsets.first().unwrap() == 0;
|
|
|
|
let last_offset_ok = *major_offsets.last().unwrap() == minor_indices.len();
|
|
|
|
if !first_offset_ok || !last_offset_ok {
|
|
|
|
return Err(SparseFormatError::from_kind_and_msg(
|
|
|
|
SparseFormatErrorKind::InvalidStructure,
|
|
|
|
"First or last offset is incompatible with format.",
|
|
|
|
));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// Set up required buffers up front
|
|
|
|
let mut minor_idx_buffer: Vec<usize> = Vec::new();
|
|
|
|
let mut values_buffer: Vec<T> = Vec::new();
|
|
|
|
let mut minor_index_permutation: Vec<usize> = Vec::new();
|
|
|
|
|
|
|
|
// Test that each lane has strictly monotonically increasing minor indices, i.e.
|
|
|
|
// minor indices within a lane are sorted, unique. Sort minor indices within a lane if needed.
|
|
|
|
// In addition, each minor index must be in bounds with respect to the minor dimension.
|
|
|
|
{
|
|
|
|
for lane_idx in 0..major_dim {
|
|
|
|
let range_start = major_offsets[lane_idx];
|
|
|
|
let range_end = major_offsets[lane_idx + 1];
|
|
|
|
|
|
|
|
// Test that major offsets are monotonically increasing
|
|
|
|
if range_start > range_end {
|
|
|
|
return Err(SparseFormatError::from_kind_and_msg(
|
|
|
|
SparseFormatErrorKind::InvalidStructure,
|
|
|
|
"Offsets are not monotonically increasing.",
|
|
|
|
));
|
|
|
|
}
|
|
|
|
|
|
|
|
let minor_idx_in_lane = minor_indices.get(range_start..range_end).ok_or(
|
|
|
|
SparseFormatError::from_kind_and_msg(
|
|
|
|
SparseFormatErrorKind::IndexOutOfBounds,
|
|
|
|
"A major offset is out of bounds.",
|
|
|
|
),
|
|
|
|
)?;
|
|
|
|
|
|
|
|
// We test for in-bounds, uniqueness and monotonicity at the same time
|
|
|
|
// to ensure that we only visit each minor index once
|
|
|
|
let mut prev = None;
|
|
|
|
let mut monotonic = true;
|
|
|
|
|
|
|
|
for &minor_idx in minor_idx_in_lane {
|
|
|
|
if minor_idx >= minor_dim {
|
|
|
|
return Err(SparseFormatError::from_kind_and_msg(
|
|
|
|
SparseFormatErrorKind::IndexOutOfBounds,
|
|
|
|
"A minor index is out of bounds.",
|
|
|
|
));
|
|
|
|
}
|
|
|
|
|
|
|
|
if let Some(prev) = prev {
|
|
|
|
if prev >= minor_idx {
|
|
|
|
if !sort {
|
|
|
|
return Err(SparseFormatError::from_kind_and_msg(
|
|
|
|
SparseFormatErrorKind::InvalidStructure,
|
|
|
|
"Minor indices are not strictly monotonically increasing in each lane.",
|
|
|
|
));
|
|
|
|
}
|
|
|
|
monotonic = false;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
prev = Some(minor_idx);
|
|
|
|
}
|
|
|
|
|
|
|
|
// sort if indices are not monotonic and sorting is expected
|
|
|
|
if !monotonic && sort {
|
|
|
|
let range_size = range_end - range_start;
|
|
|
|
minor_index_permutation.resize(range_size, 0);
|
|
|
|
compute_sort_permutation(&mut minor_index_permutation, &minor_idx_in_lane);
|
|
|
|
minor_idx_buffer.clear();
|
|
|
|
minor_idx_buffer.extend_from_slice(&minor_idx_in_lane);
|
|
|
|
apply_permutation(
|
|
|
|
&mut minor_indices[range_start..range_end],
|
|
|
|
&minor_idx_buffer,
|
|
|
|
&minor_index_permutation,
|
|
|
|
);
|
|
|
|
|
|
|
|
// check duplicates
|
|
|
|
prev = None;
|
|
|
|
for &minor_idx in &minor_indices[range_start..range_end] {
|
|
|
|
if let Some(prev) = prev {
|
|
|
|
if prev == minor_idx {
|
|
|
|
return Err(SparseFormatError::from_kind_and_msg(
|
|
|
|
SparseFormatErrorKind::DuplicateEntry,
|
|
|
|
"Input data contains duplicate entries.",
|
|
|
|
));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
prev = Some(minor_idx);
|
|
|
|
}
|
|
|
|
|
|
|
|
// sort values if they exist
|
|
|
|
if let Some(values) = values_option.as_mut() {
|
|
|
|
values_buffer.clear();
|
|
|
|
values_buffer.extend_from_slice(&values[range_start..range_end]);
|
|
|
|
apply_permutation(
|
|
|
|
&mut values[range_start..range_end],
|
|
|
|
&values_buffer,
|
|
|
|
&minor_index_permutation,
|
|
|
|
);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
Ok(())
|
|
|
|
}
|