Initial CSR and SparsityPattern impls (WIP)
This commit is contained in:
parent
1dbccfeb7c
commit
b0ffd55962
|
@ -0,0 +1,195 @@
|
||||||
|
use crate::{SparsityPattern, SparseFormatError};
|
||||||
|
use crate::iter::SparsityPatternIter;
|
||||||
|
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::slice::{IterMut, Iter};
|
||||||
|
|
||||||
|
/// A CSR representation of a sparse matrix.
|
||||||
|
///
|
||||||
|
/// The Compressed Row Storage (CSR) format is well-suited as a general-purpose storage format
|
||||||
|
/// for many sparse matrix applications.
|
||||||
|
///
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub struct CsrMatrix<T> {
|
||||||
|
// Rows are major, cols are minor in the sparsity pattern
|
||||||
|
sparsity_pattern: Arc<SparsityPattern>,
|
||||||
|
values: Vec<T>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T> CsrMatrix<T> {
|
||||||
|
/// Create a zero CSR matrix with no explicitly stored entries.
|
||||||
|
pub fn new(nrows: usize, ncols: usize) -> Self {
|
||||||
|
Self {
|
||||||
|
sparsity_pattern: Arc::new(SparsityPattern::new(nrows, ncols)),
|
||||||
|
values: vec![],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The number of rows in the matrix.
|
||||||
|
#[inline(always)]
|
||||||
|
pub fn nrows(&self) -> usize {
|
||||||
|
self.sparsity_pattern.major_dim()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The number of columns in the matrix.
|
||||||
|
#[inline(always)]
|
||||||
|
pub fn ncols(&self) -> usize {
|
||||||
|
self.sparsity_pattern.minor_dim()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The number of non-zeros in the matrix.
|
||||||
|
///
|
||||||
|
/// Note that this corresponds to the number of explicitly stored entries, *not* the actual
|
||||||
|
/// number of algebraically zero entries in the matrix. Explicitly stored entries can still
|
||||||
|
/// be zero. Corresponds to the number of entries in the sparsity pattern.
|
||||||
|
#[inline(always)]
|
||||||
|
pub fn nnz(&self) -> usize {
|
||||||
|
self.sparsity_pattern.nnz()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The row offsets defining part of the CSR format.
|
||||||
|
#[inline(always)]
|
||||||
|
pub fn row_offsets(&self) -> &[usize] {
|
||||||
|
self.sparsity_pattern.major_offsets()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The column indices defining part of the CSR format.
|
||||||
|
#[inline(always)]
|
||||||
|
pub fn column_indices(&self) -> &[usize] {
|
||||||
|
self.sparsity_pattern.minor_indices()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The non-zero values defining part of the CSR format.
|
||||||
|
#[inline(always)]
|
||||||
|
pub fn values(&self) -> &[T] {
|
||||||
|
&self.values
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Mutable access to the non-zero values.
|
||||||
|
#[inline(always)]
|
||||||
|
pub fn values_mut(&mut self) -> &mut [T] {
|
||||||
|
&mut self.values
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Try to construct a CSR matrix from raw CSR data.
|
||||||
|
///
|
||||||
|
/// It is assumed that each row contains unique and sorted column indices that are in
|
||||||
|
/// bounds with respect to the number of columns in the matrix. If this is not the case,
|
||||||
|
/// an error is returned to indicate the failure.
|
||||||
|
///
|
||||||
|
/// Panics
|
||||||
|
/// ------
|
||||||
|
/// Panics if the lengths of the provided arrays are not compatible with the CSR format.
|
||||||
|
pub fn try_from_csr_data(
|
||||||
|
num_rows: usize,
|
||||||
|
num_cols: usize,
|
||||||
|
row_offsets: Vec<usize>,
|
||||||
|
col_indices: Vec<usize>,
|
||||||
|
values: Vec<T>,
|
||||||
|
) -> Result<Self, SparseFormatError> {
|
||||||
|
assert_eq!(col_indices.len(), values.len(),
|
||||||
|
"Number of values and column indices must be the same");
|
||||||
|
let pattern = SparsityPattern::try_from_offsets_and_indices(
|
||||||
|
num_rows, num_cols, row_offsets, col_indices)?;
|
||||||
|
Ok(Self {
|
||||||
|
sparsity_pattern: Arc::new(pattern),
|
||||||
|
values,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// An iterator over non-zero triplets (i, j, v).
|
||||||
|
///
|
||||||
|
/// The iteration happens in row-major fashion, meaning that i increases monotonically,
|
||||||
|
/// and j increases monotonically within each row.
|
||||||
|
///
|
||||||
|
/// Examples
|
||||||
|
/// --------
|
||||||
|
/// ```
|
||||||
|
/// # use nalgebra_sparse::CsrMatrix;
|
||||||
|
/// let row_offsets = vec![0, 2, 3, 4];
|
||||||
|
/// let col_indices = vec![0, 2, 1, 0];
|
||||||
|
/// let values = vec![1, 2, 3, 4];
|
||||||
|
/// let mut csr = CsrMatrix::try_from_csr_data(3, 4, row_offsets, col_indices, values)
|
||||||
|
/// .unwrap();
|
||||||
|
///
|
||||||
|
/// let triplets: Vec<_> = csr.triplet_iter().map(|(i, j, v)| (i, j, *v)).collect();
|
||||||
|
/// assert_eq!(triplets, vec![(0, 0, 1), (0, 2, 2), (1, 1, 3), (2, 0, 4)]);
|
||||||
|
/// ```
|
||||||
|
pub fn triplet_iter(&self) -> CsrTripletIter<T> {
|
||||||
|
CsrTripletIter {
|
||||||
|
pattern_iter: self.sparsity_pattern.entries(),
|
||||||
|
values_iter: self.values.iter()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A mutable iterator over non-zero triplets (i, j, v).
|
||||||
|
///
|
||||||
|
/// Iteration happens in the same order as for [triplet_iter](#method.triplet_iter).
|
||||||
|
///
|
||||||
|
/// Examples
|
||||||
|
/// --------
|
||||||
|
/// ```
|
||||||
|
/// # use nalgebra_sparse::CsrMatrix;
|
||||||
|
/// # let row_offsets = vec![0, 2, 3, 4];
|
||||||
|
/// # let col_indices = vec![0, 2, 1, 0];
|
||||||
|
/// # let values = vec![1, 2, 3, 4];
|
||||||
|
/// // Using the same data as in the `triplet_iter` example
|
||||||
|
/// let mut csr = CsrMatrix::try_from_csr_data(3, 4, row_offsets, col_indices, values)
|
||||||
|
/// .unwrap();
|
||||||
|
///
|
||||||
|
/// // Zero out lower-triangular terms
|
||||||
|
/// csr.triplet_iter_mut()
|
||||||
|
/// .filter(|(i, j, _)| j < i)
|
||||||
|
/// .for_each(|(_, _, v)| *v = 0);
|
||||||
|
///
|
||||||
|
/// let triplets: Vec<_> = csr.triplet_iter().map(|(i, j, v)| (i, j, *v)).collect();
|
||||||
|
/// assert_eq!(triplets, vec![(0, 0, 1), (0, 2, 2), (1, 1, 3), (2, 0, 0)]);
|
||||||
|
/// ```
|
||||||
|
pub fn triplet_iter_mut(&mut self) -> CsrTripletIterMut<T> {
|
||||||
|
CsrTripletIterMut {
|
||||||
|
pattern_iter: self.sparsity_pattern.entries(),
|
||||||
|
values_mut_iter: self.values.iter_mut()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct CsrTripletIter<'a, T> {
|
||||||
|
pattern_iter: SparsityPatternIter<'a>,
|
||||||
|
values_iter: Iter<'a, T>
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, T> Iterator for CsrTripletIter<'a, T> {
|
||||||
|
type Item = (usize, usize, &'a T);
|
||||||
|
|
||||||
|
fn next(&mut self) -> Option<Self::Item> {
|
||||||
|
let next_entry = self.pattern_iter.next();
|
||||||
|
let next_value = self.values_iter.next();
|
||||||
|
|
||||||
|
match (next_entry, next_value) {
|
||||||
|
(Some((i, j)), Some(v)) => Some((i, j, v)),
|
||||||
|
_ => None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct CsrTripletIterMut<'a, T> {
|
||||||
|
pattern_iter: SparsityPatternIter<'a>,
|
||||||
|
values_mut_iter: IterMut<'a, T>
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, T> Iterator for CsrTripletIterMut<'a, T> {
|
||||||
|
type Item = (usize, usize, &'a mut T);
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
fn next(&mut self) -> Option<Self::Item> {
|
||||||
|
let next_entry = self.pattern_iter.next();
|
||||||
|
let next_value = self.values_mut_iter.next();
|
||||||
|
|
||||||
|
match (next_entry, next_value) {
|
||||||
|
(Some((i, j)), Some(v)) => Some((i, j, v)),
|
||||||
|
_ => None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,167 @@
|
||||||
|
use crate::SparseFormatError;
|
||||||
|
|
||||||
|
/// A representation of the sparsity pattern of a CSR or COO matrix.
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
// TODO: Make SparsityPattern parametrized by index type
|
||||||
|
// (need a solid abstraction for index types though)
|
||||||
|
pub struct SparsityPattern {
|
||||||
|
major_offsets: Vec<usize>,
|
||||||
|
minor_indices: Vec<usize>,
|
||||||
|
minor_dim: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl SparsityPattern {
|
||||||
|
/// Create a sparsity pattern of the given dimensions without explicitly stored entries.
|
||||||
|
pub fn new(major_dim: usize, minor_dim: usize) -> Self {
|
||||||
|
Self {
|
||||||
|
major_offsets: vec![0; major_dim + 1],
|
||||||
|
minor_indices: vec![],
|
||||||
|
minor_dim,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The offsets for the major dimension.
|
||||||
|
#[inline(always)]
|
||||||
|
pub fn major_offsets(&self) -> &[usize] {
|
||||||
|
&self.major_offsets
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The indices for the minor dimension.
|
||||||
|
#[inline(always)]
|
||||||
|
pub fn minor_indices(&self) -> &[usize] {
|
||||||
|
&self.minor_indices
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The major dimension.
|
||||||
|
#[inline(always)]
|
||||||
|
pub fn major_dim(&self) -> usize {
|
||||||
|
assert!(self.major_offsets.len() > 0);
|
||||||
|
self.major_offsets.len() - 1
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The minor dimension.
|
||||||
|
#[inline(always)]
|
||||||
|
pub fn minor_dim(&self) -> usize {
|
||||||
|
self.minor_dim
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The number of "non-zeros", i.e. explicitly stored entries in the pattern.
|
||||||
|
#[inline(always)]
|
||||||
|
pub fn nnz(&self) -> usize {
|
||||||
|
self.minor_indices.len()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the lane at the given index.
|
||||||
|
#[inline(always)]
|
||||||
|
pub fn lane(&self, major_index: usize) -> Option<&[usize]> {
|
||||||
|
let offset_begin = *self.major_offsets().get(major_index)?;
|
||||||
|
let offset_end = *self.major_offsets().get(major_index + 1)?;
|
||||||
|
Some(&self.minor_indices()[offset_begin..offset_end])
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Try to construct a sparsity pattern from the given dimensions, major offsets
|
||||||
|
/// and minor indices.
|
||||||
|
///
|
||||||
|
/// Returns an error if the data does not conform to the requirements.
|
||||||
|
///
|
||||||
|
/// TODO: Maybe we should not do any assertions in any of the construction functions
|
||||||
|
pub fn try_from_offsets_and_indices(
|
||||||
|
major_dim: usize,
|
||||||
|
minor_dim: usize,
|
||||||
|
major_offsets: Vec<usize>,
|
||||||
|
minor_indices: Vec<usize>,
|
||||||
|
) -> Result<Self, SparseFormatError> {
|
||||||
|
assert_eq!(major_offsets.len(), major_dim + 1);
|
||||||
|
assert_eq!(*major_offsets.last().unwrap(), minor_indices.len());
|
||||||
|
Ok(Self {
|
||||||
|
major_offsets,
|
||||||
|
minor_indices,
|
||||||
|
minor_dim,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// An iterator over the explicitly stored "non-zero" entries (i, j).
|
||||||
|
///
|
||||||
|
/// The iteration happens in a lane-major fashion, meaning that the lane index i
|
||||||
|
/// increases monotonically. and the minor index j increases monotonically within each
|
||||||
|
/// lane i.
|
||||||
|
///
|
||||||
|
/// Examples
|
||||||
|
/// --------
|
||||||
|
///
|
||||||
|
/// ```
|
||||||
|
/// # use nalgebra_sparse::{SparsityPattern};
|
||||||
|
/// let offsets = vec![0, 2, 3, 4];
|
||||||
|
/// let minor_indices = vec![0, 2, 1, 0];
|
||||||
|
/// let pattern = SparsityPattern::try_from_offsets_and_indices(3, 4, offsets, minor_indices)
|
||||||
|
/// .unwrap();
|
||||||
|
///
|
||||||
|
/// let entries: Vec<_> = pattern.entries().collect();
|
||||||
|
/// assert_eq!(entries, vec![(0, 0), (0, 2), (1, 1), (2, 0)]);
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
|
pub fn entries(&self) -> SparsityPatternIter {
|
||||||
|
SparsityPatternIter::from_pattern(self)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct SparsityPatternIter<'a> {
|
||||||
|
// See implementation of Iterator::next for an explanation of how these members are used
|
||||||
|
major_offsets: &'a [usize],
|
||||||
|
minor_indices: &'a [usize],
|
||||||
|
current_lane_idx: usize,
|
||||||
|
remaining_minors_in_lane: &'a [usize],
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> SparsityPatternIter<'a> {
|
||||||
|
fn from_pattern(pattern: &'a SparsityPattern) -> Self {
|
||||||
|
let first_lane_end = pattern.major_offsets().get(1).unwrap_or(&0);
|
||||||
|
let minors_in_first_lane = &pattern.minor_indices()[0 .. *first_lane_end];
|
||||||
|
Self {
|
||||||
|
major_offsets: pattern.major_offsets(),
|
||||||
|
minor_indices: pattern.minor_indices(),
|
||||||
|
current_lane_idx: 0,
|
||||||
|
remaining_minors_in_lane: minors_in_first_lane
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> Iterator for SparsityPatternIter<'a> {
|
||||||
|
type Item = (usize, usize);
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
fn next(&mut self) -> Option<Self::Item> {
|
||||||
|
// We ensure fast iteration across each lane by iteratively "draining" a slice
|
||||||
|
// corresponding to the remaining column indices in the particular lane.
|
||||||
|
// When we reach the end of this slice, we are at the end of a lane,
|
||||||
|
// and we must do some bookkeeping for preparing the iteration of the next lane
|
||||||
|
// (or stop iteration if we're through all lanes).
|
||||||
|
// This way we can avoid doing unnecessary bookkeeping on every iteration,
|
||||||
|
// instead paying a small price whenever we jump to a new lane.
|
||||||
|
if let Some(minor_idx) = self.remaining_minors_in_lane.first() {
|
||||||
|
let item = Some((self.current_lane_idx, *minor_idx));
|
||||||
|
self.remaining_minors_in_lane = &self.remaining_minors_in_lane[1..];
|
||||||
|
item
|
||||||
|
} else {
|
||||||
|
loop {
|
||||||
|
// Keep skipping lanes until we found a non-empty lane or there are no more lanes
|
||||||
|
if self.current_lane_idx + 2 >= self.major_offsets.len() {
|
||||||
|
// We've processed all lanes, so we're at the end of the iterator
|
||||||
|
// (note: keep in mind that offsets.len() == major_dim() + 1, hence we need +2)
|
||||||
|
return None;
|
||||||
|
} else {
|
||||||
|
// Bump lane index and check if the lane is non-empty
|
||||||
|
self.current_lane_idx += 1;
|
||||||
|
let lower = self.major_offsets[self.current_lane_idx];
|
||||||
|
let upper = self.major_offsets[self.current_lane_idx + 1];
|
||||||
|
if upper > lower {
|
||||||
|
self.remaining_minors_in_lane = &self.minor_indices[(lower + 1) .. upper];
|
||||||
|
return Some((self.current_lane_idx, self.minor_indices[lower]))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
use nalgebra_sparse::{CooMatrix, SparsePatternError};
|
use nalgebra_sparse::{CooMatrix, SparseFormatError};
|
||||||
use nalgebra::DMatrix;
|
use nalgebra::DMatrix;
|
||||||
use crate::assert_panics;
|
use crate::assert_panics;
|
||||||
|
|
||||||
|
@ -91,25 +91,25 @@ fn coo_try_from_triplets_reports_out_of_bounds_indices() {
|
||||||
{
|
{
|
||||||
// 0x0 matrix
|
// 0x0 matrix
|
||||||
let result = CooMatrix::<i32>::try_from_triplets(0, 0, vec![0], vec![0], vec![2]);
|
let result = CooMatrix::<i32>::try_from_triplets(0, 0, vec![0], vec![0], vec![2]);
|
||||||
assert!(matches!(result, Err(SparsePatternError::IndexOutOfBounds(_))));
|
assert!(matches!(result, Err(SparseFormatError::IndexOutOfBounds(_))));
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
// 1x1 matrix, row out of bounds
|
// 1x1 matrix, row out of bounds
|
||||||
let result = CooMatrix::<i32>::try_from_triplets(1, 1, vec![1], vec![0], vec![2]);
|
let result = CooMatrix::<i32>::try_from_triplets(1, 1, vec![1], vec![0], vec![2]);
|
||||||
assert!(matches!(result, Err(SparsePatternError::IndexOutOfBounds(_))));
|
assert!(matches!(result, Err(SparseFormatError::IndexOutOfBounds(_))));
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
// 1x1 matrix, col out of bounds
|
// 1x1 matrix, col out of bounds
|
||||||
let result = CooMatrix::<i32>::try_from_triplets(1, 1, vec![0], vec![1], vec![2]);
|
let result = CooMatrix::<i32>::try_from_triplets(1, 1, vec![0], vec![1], vec![2]);
|
||||||
assert!(matches!(result, Err(SparsePatternError::IndexOutOfBounds(_))));
|
assert!(matches!(result, Err(SparseFormatError::IndexOutOfBounds(_))));
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
// 1x1 matrix, row and col out of bounds
|
// 1x1 matrix, row and col out of bounds
|
||||||
let result = CooMatrix::<i32>::try_from_triplets(1, 1, vec![1], vec![1], vec![2]);
|
let result = CooMatrix::<i32>::try_from_triplets(1, 1, vec![1], vec![1], vec![2]);
|
||||||
assert!(matches!(result, Err(SparsePatternError::IndexOutOfBounds(_))));
|
assert!(matches!(result, Err(SparseFormatError::IndexOutOfBounds(_))));
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
|
@ -118,7 +118,7 @@ fn coo_try_from_triplets_reports_out_of_bounds_indices() {
|
||||||
let j = vec![0, 2, 1, 3, 3];
|
let j = vec![0, 2, 1, 3, 3];
|
||||||
let v = vec![2, 3, 7, 3, 1];
|
let v = vec![2, 3, 7, 3, 1];
|
||||||
let result = CooMatrix::<i32>::try_from_triplets(3, 5, i, j, v);
|
let result = CooMatrix::<i32>::try_from_triplets(3, 5, i, j, v);
|
||||||
assert!(matches!(result, Err(SparsePatternError::IndexOutOfBounds(_))));
|
assert!(matches!(result, Err(SparseFormatError::IndexOutOfBounds(_))));
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
|
@ -127,7 +127,7 @@ fn coo_try_from_triplets_reports_out_of_bounds_indices() {
|
||||||
let j = vec![0, 2, 1, 5, 3];
|
let j = vec![0, 2, 1, 5, 3];
|
||||||
let v = vec![2, 3, 7, 3, 1];
|
let v = vec![2, 3, 7, 3, 1];
|
||||||
let result = CooMatrix::<i32>::try_from_triplets(3, 5, i, j, v);
|
let result = CooMatrix::<i32>::try_from_triplets(3, 5, i, j, v);
|
||||||
assert!(matches!(result, Err(SparsePatternError::IndexOutOfBounds(_))));
|
assert!(matches!(result, Err(SparseFormatError::IndexOutOfBounds(_))));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue