From b59c4a3216f103c085830307c74ca3ba35a54335 Mon Sep 17 00:00:00 2001 From: Andreas Longva Date: Tue, 22 Dec 2020 10:19:17 +0100 Subject: [PATCH] Refactor most of Csr/CscMatrix logic into helper type CsMatrix Still need to update CSC API so that it mirrors CsrMatrix in terms of get_entry and so on. --- nalgebra-sparse/src/cs.rs | 286 +++++++++++++++++++++++++++++++++++++ nalgebra-sparse/src/csc.rs | 157 ++++++-------------- nalgebra-sparse/src/csr.rs | 210 ++++++++------------------- nalgebra-sparse/src/lib.rs | 2 + 4 files changed, 394 insertions(+), 261 deletions(-) create mode 100644 nalgebra-sparse/src/cs.rs diff --git a/nalgebra-sparse/src/cs.rs b/nalgebra-sparse/src/cs.rs new file mode 100644 index 00000000..7a38055b --- /dev/null +++ b/nalgebra-sparse/src/cs.rs @@ -0,0 +1,286 @@ +use crate::pattern::SparsityPattern; +use crate::{SparseEntry, SparseEntryMut}; + +use std::sync::Arc; +use std::ops::Range; +use std::ptr::slice_from_raw_parts_mut; + +/// An abstract compressed matrix. +/// +/// For the time being, this is only used internally to share implementation between +/// CSR and CSC matrices. +/// +/// A CSR matrix is obtained by associating rows with the major dimension, while a CSC matrix +/// is obtained by associating columns with the major dimension. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct CsMatrix { + sparsity_pattern: Arc, + values: Vec +} + +impl CsMatrix { + /// Create a zero matrix with no explicitly stored entries. + #[inline] + pub fn new(major_dim: usize, minor_dim: usize) -> Self { + Self { + sparsity_pattern: Arc::new(SparsityPattern::new(major_dim, minor_dim)), + values: vec![], + } + } + + #[inline] + pub fn pattern(&self) -> &Arc { + &self.sparsity_pattern + } + + #[inline] + pub fn values(&self) -> &[T] { + &self.values + } + + #[inline] + pub fn values_mut(&mut self) -> &mut [T] { + &mut self.values + } + + /// Returns the raw data represented as a tuple `(major_offsets, minor_indices, values)`. + #[inline] + pub fn cs_data(&self) -> (&[usize], &[usize], &[T]) { + let pattern = self.pattern().as_ref(); + (pattern.major_offsets(), pattern.minor_indices(), &self.values) + } + + /// Returns the raw data represented as a tuple `(major_offsets, minor_indices, values)`. + #[inline] + pub fn cs_data_mut(&mut self) -> (&[usize], &[usize], &mut [T]) { + let pattern = self.sparsity_pattern.as_ref(); + (pattern.major_offsets(), pattern.minor_indices(), &mut self.values) + } + + #[inline] + pub fn pattern_and_values_mut(&mut self) -> (&Arc, &mut [T]) { + (&self.sparsity_pattern, &mut self.values) + } + + #[inline] + pub fn from_pattern_and_values(pattern: Arc, values: Vec) + -> Self { + assert_eq!(pattern.nnz(), values.len(), "Internal error: consumers should verify shape compatibility."); + Self { + sparsity_pattern: pattern, + values, + } + } + + /// Internal method for simplifying access to a lane's data + #[inline] + pub fn get_index_range(&self, row_index: usize) -> Option> { + let row_begin = *self.sparsity_pattern.major_offsets().get(row_index)?; + let row_end = *self.sparsity_pattern.major_offsets().get(row_index + 1)?; + Some(row_begin .. row_end) + } + + pub fn take_pattern_and_values(self) -> (Arc, Vec) { + (self.sparsity_pattern, self.values) + } + + #[inline] + pub fn disassemble(self) -> (Vec, Vec, Vec) { + // Take an Arc to the pattern, which might be the sole reference to the data after + // taking the values. This is important, because it might let us avoid cloning the data + // further below. + let pattern = self.sparsity_pattern; + let values = self.values; + + // Try to take the pattern out of the `Arc` if possible, + // otherwise clone the pattern. + let owned_pattern = Arc::try_unwrap(pattern) + .unwrap_or_else(|arc| SparsityPattern::clone(&*arc)); + let (offsets, indices) = owned_pattern.disassemble(); + + (offsets, indices, values) + } + + /// Returns an entry for the given major/minor indices, or `None` if the indices are out + /// of bounds. + pub fn get_entry(&self, major_index: usize, minor_index: usize) -> Option> { + let row_range = self.get_index_range(major_index)?; + let (_, minor_indices, values) = self.cs_data(); + let minor_indices = &minor_indices[row_range.clone()]; + let values = &values[row_range]; + get_entry_from_slices(self.pattern().minor_dim(), minor_indices, values, minor_index) + } + + /// Returns a mutable entry for the given major/minor indices, or `None` if the indices are out + /// of bounds. + pub fn get_entry_mut(&mut self, major_index: usize, minor_index: usize) + -> Option> { + let row_range = self.get_index_range(major_index)?; + let minor_dim = self.pattern().minor_dim(); + let (_, minor_indices, values) = self.cs_data_mut(); + let minor_indices = &minor_indices[row_range.clone()]; + let values = &mut values[row_range]; + get_mut_entry_from_slices(minor_dim, minor_indices, values, minor_index) + } + + pub fn get_lane(&self, index: usize) -> Option> { + let range = self.get_index_range(index)?; + let (_, minor_indices, values) = self.cs_data(); + Some(CsLane { + minor_indices: &minor_indices[range.clone()], + values: &values[range], + minor_dim: self.pattern().minor_dim() + }) + } + + #[inline] + pub fn get_lane_mut(&mut self, index: usize) -> Option> { + let range = self.get_index_range(index)?; + let minor_dim = self.pattern().minor_dim(); + let (_, minor_indices, values) = self.cs_data_mut(); + Some(CsLaneMut { + minor_dim, + minor_indices: &minor_indices[range.clone()], + values: &mut values[range] + }) + } +} + +pub fn get_entry_from_slices<'a, T>( + minor_dim: usize, + minor_indices: &'a [usize], + values: &'a [T], + global_minor_index: usize) -> Option> { + let local_index = minor_indices.binary_search(&global_minor_index); + if let Ok(local_index) = local_index { + Some(SparseEntry::NonZero(&values[local_index])) + } else if global_minor_index < minor_dim { + Some(SparseEntry::Zero) + } else { + None + } +} + +pub fn get_mut_entry_from_slices<'a, T>( + minor_dim: usize, + minor_indices: &'a [usize], + values: &'a mut [T], + global_minor_indices: usize) -> Option> { + let local_index = minor_indices.binary_search(&global_minor_indices); + if let Ok(local_index) = local_index { + Some(SparseEntryMut::NonZero(&mut values[local_index])) + } else if global_minor_indices < minor_dim { + Some(SparseEntryMut::Zero) + } else { + None + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct CsLane<'a, T> { + pub minor_dim: usize, + pub minor_indices: &'a [usize], + pub values: &'a [T] +} + +#[derive(Debug, PartialEq, Eq)] +pub struct CsLaneMut<'a, T> { + pub minor_dim: usize, + pub minor_indices: &'a [usize], + pub values: &'a mut [T] +} + +pub struct CsLaneIter<'a, T> { + // The index of the lane that will be returned on the next iteration + current_lane_idx: usize, + pattern: &'a SparsityPattern, + remaining_values: &'a [T], +} + +impl<'a, T> CsLaneIter<'a, T> { + pub fn new(pattern: &'a SparsityPattern, values: &'a [T]) -> Self { + Self { + current_lane_idx: 0, + pattern, + remaining_values: values + } + } +} + +impl<'a, T> Iterator for CsLaneIter<'a, T> + where + T: 'a +{ + type Item = CsLane<'a, T>; + + fn next(&mut self) -> Option { + let lane = self.pattern.get_lane(self.current_lane_idx); + let minor_dim = self.pattern.minor_dim(); + + if let Some(minor_indices) = lane { + let count = minor_indices.len(); + let values_in_lane = &self.remaining_values[..count]; + self.remaining_values = &self.remaining_values[count ..]; + self.current_lane_idx += 1; + + Some(CsLane { + minor_dim, + minor_indices, + values: values_in_lane + }) + } else { + None + } + } +} + +pub struct CsLaneIterMut<'a, T> { + // The index of the lane that will be returned on the next iteration + current_lane_idx: usize, + pattern: &'a SparsityPattern, + remaining_values: *mut T, +} + +impl<'a, T> CsLaneIterMut<'a, T> { + pub fn new(pattern: &'a SparsityPattern, values: &'a mut [T]) -> Self { + Self { + current_lane_idx: 0, + pattern, + remaining_values: values.as_mut_ptr() + } + } +} + +impl<'a, T> Iterator for CsLaneIterMut<'a, T> + where + T: 'a +{ + type Item = CsLaneMut<'a, T>; + + fn next(&mut self) -> Option { + let lane = self.pattern.get_lane(self.current_lane_idx); + let minor_dim = self.pattern.minor_dim(); + + if let Some(minor_indices) = lane { + let count = minor_indices.len(); + + // Note: I can't think of any way to construct this iterator without unsafe. + let values_in_lane; + unsafe { + values_in_lane = &mut *slice_from_raw_parts_mut(self.remaining_values, count); + self.remaining_values = self.remaining_values.add(count); + } + self.current_lane_idx += 1; + + Some(CsLaneMut { + minor_dim, + minor_indices, + values: values_in_lane + }) + } else { + None + } + } +} + + diff --git a/nalgebra-sparse/src/csc.rs b/nalgebra-sparse/src/csc.rs index a94a5fdc..f39483a2 100644 --- a/nalgebra-sparse/src/csc.rs +++ b/nalgebra-sparse/src/csc.rs @@ -2,13 +2,12 @@ use crate::{SparseFormatError, SparseFormatErrorKind}; use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter}; +use crate::csr::CsrMatrix; +use crate::cs::{CsMatrix, CsLane, CsLaneMut, CsLaneIter, CsLaneIterMut}; use std::sync::Arc; use std::slice::{IterMut, Iter}; -use std::ops::Range; use num_traits::Zero; -use std::ptr::slice_from_raw_parts_mut; -use crate::csr::CsrMatrix; use nalgebra::Scalar; /// A CSC representation of a sparse matrix. @@ -21,29 +20,27 @@ use nalgebra::Scalar; #[derive(Debug, Clone, PartialEq, Eq)] pub struct CscMatrix { // Cols are major, rows are minor in the sparsity pattern - sparsity_pattern: Arc, - values: Vec, + cs: CsMatrix, } impl CscMatrix { /// Create a zero CSC matrix with no explicitly stored entries. pub fn new(nrows: usize, ncols: usize) -> Self { Self { - sparsity_pattern: Arc::new(SparsityPattern::new(ncols, nrows)), - values: vec![], + cs: CsMatrix::new(ncols, nrows) } } /// The number of rows in the matrix. #[inline] pub fn nrows(&self) -> usize { - self.sparsity_pattern.minor_dim() + self.cs.pattern().minor_dim() } /// The number of columns in the matrix. #[inline] pub fn ncols(&self) -> usize { - self.sparsity_pattern.major_dim() + self.cs.pattern().major_dim() } /// The number of non-zeros in the matrix. @@ -53,31 +50,31 @@ impl CscMatrix { /// be zero. Corresponds to the number of entries in the sparsity pattern. #[inline] pub fn nnz(&self) -> usize { - self.sparsity_pattern.nnz() + self.pattern().nnz() } /// The column offsets defining part of the CSC format. #[inline] pub fn col_offsets(&self) -> &[usize] { - self.sparsity_pattern.major_offsets() + self.pattern().major_offsets() } /// The row indices defining part of the CSC format. #[inline] pub fn row_indices(&self) -> &[usize] { - self.sparsity_pattern.minor_indices() + self.pattern().minor_indices() } /// The non-zero values defining part of the CSC format. #[inline] pub fn values(&self) -> &[T] { - &self.values + self.cs.values() } /// Mutable access to the non-zero values. #[inline] pub fn values_mut(&mut self) -> &mut [T] { - &mut self.values + self.cs.values_mut() } /// Try to construct a CSC matrix from raw CSC data. @@ -109,8 +106,7 @@ impl CscMatrix { -> Result { if pattern.nnz() == values.len() { Ok(Self { - sparsity_pattern: pattern, - values, + cs: CsMatrix::from_pattern_and_values(pattern, values) }) } else { Err(SparseFormatError::from_kind_and_msg( @@ -140,8 +136,8 @@ impl CscMatrix { /// ``` pub fn triplet_iter(&self) -> CscTripletIter { CscTripletIter { - pattern_iter: self.sparsity_pattern.entries(), - values_iter: self.values.iter() + pattern_iter: self.pattern().entries(), + values_iter: self.values().iter() } } @@ -169,9 +165,10 @@ impl CscMatrix { /// assert_eq!(triplets, vec![(0, 0, 1), (2, 0, 0), (1, 1, 2), (0, 2, 4)]); /// ``` pub fn triplet_iter_mut(&mut self) -> CscTripletIterMut { + let (pattern, values) = self.cs.pattern_and_values_mut(); CscTripletIterMut { - pattern_iter: self.sparsity_pattern.entries(), - values_mut_iter: self.values.iter_mut() + pattern_iter: pattern.entries(), + values_mut_iter: values.iter_mut() } } @@ -200,54 +197,34 @@ impl CscMatrix { /// Return the column at the given column index, or `None` if out of bounds. #[inline] pub fn get_col(&self, index: usize) -> Option> { - let range = self.get_index_range(index)?; - Some(CscCol { - row_indices: &self.sparsity_pattern.minor_indices()[range.clone()], - values: &self.values[range], - nrows: self.nrows() - }) + self.cs + .get_lane(index) + .map(|lane| CscCol { lane }) } /// Mutable column access for the given column index, or `None` if out of bounds. #[inline] pub fn get_col_mut(&mut self, index: usize) -> Option> { - let range = self.get_index_range(index)?; - Some(CscColMut { - nrows: self.nrows(), - row_indices: &self.sparsity_pattern.minor_indices()[range.clone()], - values: &mut self.values[range] - }) - } - - /// Internal method for simplifying access to a column's data. - fn get_index_range(&self, col_index: usize) -> Option> { - let col_begin = *self.sparsity_pattern.major_offsets().get(col_index)?; - let col_end = *self.sparsity_pattern.major_offsets().get(col_index + 1)?; - Some(col_begin .. col_end) + self.cs + .get_lane_mut(index) + .map(|lane| CscColMut { lane }) } /// An iterator over columns in the matrix. pub fn col_iter(&self) -> CscColIter { CscColIter { - current_col_idx: 0, - matrix: self + lane_iter: CsLaneIter::new(self.pattern().as_ref(), self.values()) } } /// A mutable iterator over columns in the matrix. pub fn col_iter_mut(&mut self) -> CscColIterMut { + let (pattern, values) = self.cs.pattern_and_values_mut(); CscColIterMut { - current_col_idx: 0, - pattern: &self.sparsity_pattern, - remaining_values: self.values.as_mut_ptr() + lane_iter: CsLaneIterMut::new(pattern, values) } } - /// Returns the underlying vector containing the values for the explicitly stored entries. - pub fn take_values(self) -> Vec { - self.values - } - /// Disassembles the CSC matrix into its underlying offset, index and value arrays. /// /// If the matrix contains the sole reference to the sparsity pattern, @@ -274,19 +251,7 @@ impl CscMatrix { /// assert_eq!(values2, values); /// ``` pub fn disassemble(self) -> (Vec, Vec, Vec) { - // Take an Arc to the pattern, which might be the sole reference to the data after - // taking the values. This is important, because it might let us avoid cloning the data - // further below. - let pattern = self.sparsity_pattern; - let values = self.values; - - // Try to take the pattern out of the `Arc` if possible, - // otherwise clone the pattern. - let owned_pattern = Arc::try_unwrap(pattern) - .unwrap_or_else(|arc| SparsityPattern::clone(&*arc)); - let (offsets, indices) = owned_pattern.disassemble(); - - (offsets, indices, values) + self.cs.disassemble() } /// Returns the underlying sparsity pattern. @@ -295,15 +260,14 @@ impl CscMatrix { /// the same sparsity pattern for multiple matrices without storing the same pattern multiple /// times in memory. pub fn pattern(&self) -> &Arc { - &self.sparsity_pattern + self.cs.pattern() } /// Reinterprets the CSC matrix as its transpose represented by a CSR matrix. /// /// This operation does not touch the CSC data, and is effectively a no-op. pub fn transpose_as_csr(self) -> CsrMatrix { - let pattern = self.sparsity_pattern; - let values = self.values; + let (pattern, values) = self.cs.take_pattern_and_values(); CsrMatrix::try_from_pattern_and_values(pattern, values).unwrap() } } @@ -422,9 +386,7 @@ impl<'a, T> Iterator for CscTripletIterMut<'a, T> { /// An immutable representation of a column in a CSC matrix. #[derive(Debug, Clone, PartialEq, Eq)] pub struct CscCol<'a, T> { - nrows: usize, - row_indices: &'a [usize], - values: &'a [T], + lane: CsLane<'a, T> } /// A mutable representation of a column in a CSC matrix. @@ -433,9 +395,7 @@ pub struct CscCol<'a, T> { /// to the column cannot be modified. #[derive(Debug, PartialEq, Eq)] pub struct CscColMut<'a, T> { - nrows: usize, - row_indices: &'a [usize], - values: &'a mut [T] + lane: CsLaneMut<'a, T> } /// Implement the methods common to both CscCol and CscColMut @@ -445,25 +405,25 @@ macro_rules! impl_csc_col_common_methods { /// The number of global rows in the column. #[inline] pub fn nrows(&self) -> usize { - self.nrows + self.lane.minor_dim } /// The number of non-zeros in this column. #[inline] pub fn nnz(&self) -> usize { - self.row_indices.len() + self.lane.minor_indices.len() } /// The row indices corresponding to explicitly stored entries in this column. #[inline] pub fn row_indices(&self) -> &[usize] { - self.row_indices + self.lane.minor_indices } /// The values corresponding to explicitly stored entries in this column. #[inline] pub fn values(&self) -> &[T] { - self.values + self.lane.values } } @@ -480,8 +440,8 @@ macro_rules! impl_csc_col_common_methods { pub fn get(&self, global_row_index: usize) -> Option { let local_index = self.row_indices().binary_search(&global_row_index); if let Ok(local_index) = local_index { - Some(self.values[local_index].clone()) - } else if global_row_index < self.nrows { + Some(self.values()[local_index].clone()) + } else if global_row_index < self.lane.minor_dim { Some(T::zero()) } else { None @@ -497,7 +457,7 @@ impl_csc_col_common_methods!(CscColMut<'a, T>); impl<'a, T> CscColMut<'a, T> { /// Mutable access to the values corresponding to explicitly stored entries in this column. pub fn values_mut(&mut self) -> &mut [T] { - self.values + self.lane.values } /// Provides simultaneous access to row indices and mutable values corresponding to the @@ -506,32 +466,28 @@ impl<'a, T> CscColMut<'a, T> { /// This method primarily facilitates low-level access for methods that process data stored /// in CSC format directly. pub fn rows_and_values_mut(&mut self) -> (&[usize], &mut [T]) { - (self.row_indices, self.values) + (self.lane.minor_indices, self.lane.values) } } /// Column iterator for [CscMatrix](struct.CscMatrix.html). pub struct CscColIter<'a, T> { - // The index of the row that will be returned on the next - current_col_idx: usize, - matrix: &'a CscMatrix + lane_iter: CsLaneIter<'a, T> } impl<'a, T> Iterator for CscColIter<'a, T> { type Item = CscCol<'a, T>; fn next(&mut self) -> Option { - let col = self.matrix.get_col(self.current_col_idx); - self.current_col_idx += 1; - col + self.lane_iter + .next() + .map(|lane| CscCol { lane }) } } /// Mutable column iterator for [CscMatrix](struct.CscMatrix.html). pub struct CscColIterMut<'a, T> { - current_col_idx: usize, - pattern: &'a SparsityPattern, - remaining_values: *mut T, + lane_iter: CsLaneIterMut<'a, T> } impl<'a, T> Iterator for CscColIterMut<'a, T> @@ -541,27 +497,8 @@ where type Item = CscColMut<'a, T>; fn next(&mut self) -> Option { - let lane = self.pattern.get_lane(self.current_col_idx); - let nrows = self.pattern.minor_dim(); - - if let Some(row_indices) = lane { - let count = row_indices.len(); - - // Note: I can't think of any way to construct this iterator without unsafe. - let values_in_row; - unsafe { - values_in_row = &mut *slice_from_raw_parts_mut(self.remaining_values, count); - self.remaining_values = self.remaining_values.add(count); - } - self.current_col_idx += 1; - - Some(CscColMut { - nrows, - row_indices, - values: values_in_row - }) - } else { - None - } + self.lane_iter + .next() + .map(|lane| CscColMut { lane }) } } \ No newline at end of file diff --git a/nalgebra-sparse/src/csr.rs b/nalgebra-sparse/src/csr.rs index 20db241b..fa10a69d 100644 --- a/nalgebra-sparse/src/csr.rs +++ b/nalgebra-sparse/src/csr.rs @@ -2,14 +2,13 @@ use crate::{SparseFormatError, SparseFormatErrorKind, SparseEntry, SparseEntryMut}; use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter}; use crate::csc::CscMatrix; +use crate::cs::{CsMatrix, get_entry_from_slices, get_mut_entry_from_slices, CsLaneIterMut, CsLaneIter, CsLane, CsLaneMut}; use nalgebra::Scalar; use num_traits::Zero; use std::sync::Arc; use std::slice::{IterMut, Iter}; -use std::ops::Range; -use std::ptr::slice_from_raw_parts_mut; /// A CSR representation of a sparse matrix. /// @@ -21,29 +20,27 @@ use std::ptr::slice_from_raw_parts_mut; #[derive(Debug, Clone, PartialEq, Eq)] pub struct CsrMatrix { // Rows are major, cols are minor in the sparsity pattern - sparsity_pattern: Arc, - values: Vec, + cs: CsMatrix, } impl CsrMatrix { /// Create a zero CSR matrix with no explicitly stored entries. pub fn new(nrows: usize, ncols: usize) -> Self { Self { - sparsity_pattern: Arc::new(SparsityPattern::new(nrows, ncols)), - values: vec![], + cs: CsMatrix::new(nrows, ncols) } } /// The number of rows in the matrix. #[inline] pub fn nrows(&self) -> usize { - self.sparsity_pattern.major_dim() + self.cs.pattern().major_dim() } /// The number of columns in the matrix. #[inline] pub fn ncols(&self) -> usize { - self.sparsity_pattern.minor_dim() + self.cs.pattern().minor_dim() } /// The number of non-zeros in the matrix. @@ -53,31 +50,33 @@ impl CsrMatrix { /// be zero. Corresponds to the number of entries in the sparsity pattern. #[inline] pub fn nnz(&self) -> usize { - self.sparsity_pattern.nnz() + self.cs.pattern().nnz() } /// The row offsets defining part of the CSR format. #[inline] pub fn row_offsets(&self) -> &[usize] { - self.sparsity_pattern.major_offsets() + let (offsets, _, _) = self.cs.cs_data(); + offsets } /// The column indices defining part of the CSR format. #[inline] pub fn col_indices(&self) -> &[usize] { - self.sparsity_pattern.minor_indices() + let (_, indices, _) = self.cs.cs_data(); + indices } /// The non-zero values defining part of the CSR format. #[inline] pub fn values(&self) -> &[T] { - &self.values + self.cs.values() } /// Mutable access to the non-zero values. #[inline] pub fn values_mut(&mut self) -> &mut [T] { - &mut self.values + self.cs.values_mut() } /// Try to construct a CSR matrix from raw CSR data. @@ -109,8 +108,7 @@ impl CsrMatrix { -> Result { if pattern.nnz() == values.len() { Ok(Self { - sparsity_pattern: pattern, - values, + cs: CsMatrix::from_pattern_and_values(pattern, values) }) } else { Err(SparseFormatError::from_kind_and_msg( @@ -119,7 +117,6 @@ impl CsrMatrix { } } - /// An iterator over non-zero triplets (i, j, v). /// /// The iteration happens in row-major fashion, meaning that i increases monotonically, @@ -140,8 +137,8 @@ impl CsrMatrix { /// ``` pub fn triplet_iter(&self) -> CsrTripletIter { CsrTripletIter { - pattern_iter: self.sparsity_pattern.entries(), - values_iter: self.values.iter() + pattern_iter: self.pattern().entries(), + values_iter: self.values().iter() } } @@ -169,9 +166,10 @@ impl CsrMatrix { /// assert_eq!(triplets, vec![(0, 0, 1), (0, 2, 2), (1, 1, 3), (2, 0, 0)]); /// ``` pub fn triplet_iter_mut(&mut self) -> CsrTripletIterMut { + let (pattern, values) = self.cs.pattern_and_values_mut(); CsrTripletIterMut { - pattern_iter: self.sparsity_pattern.entries(), - values_mut_iter: self.values.iter_mut() + pattern_iter: pattern.entries(), + values_mut_iter: values.iter_mut() } } @@ -200,54 +198,34 @@ impl CsrMatrix { /// Return the row at the given row index, or `None` if out of bounds. #[inline] pub fn get_row(&self, index: usize) -> Option> { - let range = self.get_index_range(index)?; - Some(CsrRow { - col_indices: &self.sparsity_pattern.minor_indices()[range.clone()], - values: &self.values[range], - ncols: self.ncols() - }) + self.cs + .get_lane(index) + .map(|lane| CsrRow { lane }) } /// Mutable row access for the given row index, or `None` if out of bounds. #[inline] pub fn get_row_mut(&mut self, index: usize) -> Option> { - let range = self.get_index_range(index)?; - Some(CsrRowMut { - ncols: self.ncols(), - col_indices: &self.sparsity_pattern.minor_indices()[range.clone()], - values: &mut self.values[range] - }) - } - - /// Internal method for simplifying access to a row's data. - fn get_index_range(&self, row_index: usize) -> Option> { - let row_begin = *self.sparsity_pattern.major_offsets().get(row_index)?; - let row_end = *self.sparsity_pattern.major_offsets().get(row_index + 1)?; - Some(row_begin .. row_end) + self.cs + .get_lane_mut(index) + .map(|lane| CsrRowMut { lane }) } /// An iterator over rows in the matrix. pub fn row_iter(&self) -> CsrRowIter { CsrRowIter { - current_row_idx: 0, - matrix: self + lane_iter: CsLaneIter::new(self.pattern().as_ref(), self.values()) } } /// A mutable iterator over rows in the matrix. pub fn row_iter_mut(&mut self) -> CsrRowIterMut { + let (pattern, values) = self.cs.pattern_and_values_mut(); CsrRowIterMut { - current_row_idx: 0, - pattern: &self.sparsity_pattern, - remaining_values: self.values.as_mut_ptr() + lane_iter: CsLaneIterMut::new(pattern, values), } } - /// Returns the underlying vector containing the values for the explicitly stored entries. - pub fn take_values(self) -> Vec { - self.values - } - /// Disassembles the CSR matrix into its underlying offset, index and value arrays. /// /// If the matrix contains the sole reference to the sparsity pattern, @@ -274,19 +252,7 @@ impl CsrMatrix { /// assert_eq!(values2, values); /// ``` pub fn disassemble(self) -> (Vec, Vec, Vec) { - // Take an Arc to the pattern, which might be the sole reference to the data after - // taking the values. This is important, because it might let us avoid cloning the data - // further below. - let pattern = self.sparsity_pattern; - let values = self.values; - - // Try to take the pattern out of the `Arc` if possible, - // otherwise clone the pattern. - let owned_pattern = Arc::try_unwrap(pattern) - .unwrap_or_else(|arc| SparsityPattern::clone(&*arc)); - let (offsets, indices) = owned_pattern.disassemble(); - - (offsets, indices, values) + self.cs.disassemble() } /// Returns the underlying sparsity pattern. @@ -295,15 +261,14 @@ impl CsrMatrix { /// the same sparsity pattern for multiple matrices without storing the same pattern multiple /// times in memory. pub fn pattern(&self) -> &Arc { - &self.sparsity_pattern + self.cs.pattern() } /// Reinterprets the CSR matrix as its transpose represented by a CSC matrix. /// /// This operation does not touch the CSR data, and is effectively a no-op. pub fn transpose_as_csc(self) -> CscMatrix { - let pattern = self.sparsity_pattern; - let values = self.values; + let (pattern, values) = self.cs.take_pattern_and_values(); CscMatrix::try_from_pattern_and_values(pattern, values).unwrap() } @@ -312,10 +277,7 @@ impl CsrMatrix { /// Each call to this function incurs the cost of a binary search among the explicitly /// stored column entries for the given row. pub fn get_entry(&self, row_index: usize, col_index: usize) -> Option> { - let row_range = self.get_index_range(row_index)?; - let col_indices = &self.col_indices()[row_range.clone()]; - let values = &self.values()[row_range]; - get_entry_from_slices(self.ncols(), col_indices, values, col_index) + self.cs.get_entry(row_index, col_index) } /// Returns a mutable entry for the given row/col indices, or `None` if the indices are out @@ -325,12 +287,7 @@ impl CsrMatrix { /// stored column entries for the given row. pub fn get_entry_mut(&mut self, row_index: usize, col_index: usize) -> Option> { - let row_range = self.get_index_range(row_index)?; - let ncols = self.ncols(); - let (_, col_indices, values) = self.csr_data_mut(); - let col_indices = &col_indices[row_range.clone()]; - let values = &mut values[row_range]; - get_mut_entry_from_slices(ncols, col_indices, values, col_index) + self.cs.get_entry_mut(row_index, col_index) } /// Returns an entry for the given row/col indices. @@ -361,14 +318,13 @@ impl CsrMatrix { /// Returns a triplet of slices `(row_offsets, col_indices, values)` that make up the CSR data. pub fn csr_data(&self) -> (&[usize], &[usize], &[T]) { - (self.row_offsets(), self.col_indices(), self.values()) + self.cs.cs_data() } /// Returns a triplet of slices `(row_offsets, col_indices, values)` that make up the CSR data, /// where the `values` array is mutable. pub fn csr_data_mut(&mut self) -> (&[usize], &[usize], &mut [T]) { - let pattern = self.sparsity_pattern.as_ref(); - (pattern.major_offsets(), pattern.minor_indices(), &mut self.values) + self.cs.cs_data_mut() } } @@ -460,9 +416,7 @@ impl<'a, T> Iterator for CsrTripletIterMut<'a, T> { /// An immutable representation of a row in a CSR matrix. #[derive(Debug, Clone, PartialEq, Eq)] pub struct CsrRow<'a, T> { - ncols: usize, - col_indices: &'a [usize], - values: &'a [T], + lane: CsLane<'a, T> } /// A mutable representation of a row in a CSR matrix. @@ -471,9 +425,7 @@ pub struct CsrRow<'a, T> { /// to the row cannot be modified. #[derive(Debug, PartialEq, Eq)] pub struct CsrRowMut<'a, T> { - ncols: usize, - col_indices: &'a [usize], - values: &'a mut [T] + lane: CsLaneMut<'a, T> } /// Implement the methods common to both CsrRow and CsrRowMut @@ -483,25 +435,25 @@ macro_rules! impl_csr_row_common_methods { /// The number of global columns in the row. #[inline] pub fn ncols(&self) -> usize { - self.ncols + self.lane.minor_dim } /// The number of non-zeros in this row. #[inline] pub fn nnz(&self) -> usize { - self.col_indices.len() + self.lane.minor_indices.len() } /// The column indices corresponding to explicitly stored entries in this row. #[inline] pub fn col_indices(&self) -> &[usize] { - self.col_indices + self.lane.minor_indices } /// The values corresponding to explicitly stored entries in this row. #[inline] pub fn values(&self) -> &[T] { - self.values + self.lane.values } /// Returns an entry for the given global column index. @@ -509,47 +461,23 @@ macro_rules! impl_csr_row_common_methods { /// Each call to this function incurs the cost of a binary search among the explicitly /// stored column entries. pub fn get_entry(&self, global_col_index: usize) -> Option> { - get_entry_from_slices(self.ncols, self.col_indices, self.values, global_col_index) + get_entry_from_slices( + self.lane.minor_dim, + self.lane.minor_indices, + self.lane.values, + global_col_index) } } } } -fn get_entry_from_slices<'a, T>(ncols: usize, - col_indices: &'a [usize], - values: &'a [T], - global_col_index: usize) -> Option> { - let local_index = col_indices.binary_search(&global_col_index); - if let Ok(local_index) = local_index { - Some(SparseEntry::NonZero(&values[local_index])) - } else if global_col_index < ncols { - Some(SparseEntry::Zero) - } else { - None - } -} - -fn get_mut_entry_from_slices<'a, T>(ncols: usize, - col_indices: &'a [usize], - values: &'a mut [T], - global_col_index: usize) -> Option> { - let local_index = col_indices.binary_search(&global_col_index); - if let Ok(local_index) = local_index { - Some(SparseEntryMut::NonZero(&mut values[local_index])) - } else if global_col_index < ncols { - Some(SparseEntryMut::Zero) - } else { - None - } -} - impl_csr_row_common_methods!(CsrRow<'a, T>); impl_csr_row_common_methods!(CsrRowMut<'a, T>); impl<'a, T> CsrRowMut<'a, T> { /// Mutable access to the values corresponding to explicitly stored entries in this row. pub fn values_mut(&mut self) -> &mut [T] { - self.values + self.lane.values } /// Provides simultaneous access to column indices and mutable values corresponding to the @@ -558,37 +486,36 @@ impl<'a, T> CsrRowMut<'a, T> { /// This method primarily facilitates low-level access for methods that process data stored /// in CSR format directly. pub fn cols_and_values_mut(&mut self) -> (&[usize], &mut [T]) { - (self.col_indices, self.values) + (self.lane.minor_indices, self.lane.values) } /// Returns a mutable entry for the given global column index. pub fn get_entry_mut(&mut self, global_col_index: usize) -> Option> { - get_mut_entry_from_slices(self.ncols, self.col_indices, self.values, global_col_index) + get_mut_entry_from_slices(self.lane.minor_dim, + self.lane.minor_indices, + self.lane.values, + global_col_index) } } /// Row iterator for [CsrMatrix](struct.CsrMatrix.html). pub struct CsrRowIter<'a, T> { - // The index of the row that will be returned on the next - current_row_idx: usize, - matrix: &'a CsrMatrix + lane_iter: CsLaneIter<'a, T> } impl<'a, T> Iterator for CsrRowIter<'a, T> { type Item = CsrRow<'a, T>; fn next(&mut self) -> Option { - let row = self.matrix.get_row(self.current_row_idx); - self.current_row_idx += 1; - row + self.lane_iter + .next() + .map(|lane| CsrRow { lane }) } } /// Mutable row iterator for [CsrMatrix](struct.CsrMatrix.html). pub struct CsrRowIterMut<'a, T> { - current_row_idx: usize, - pattern: &'a SparsityPattern, - remaining_values: *mut T, + lane_iter: CsLaneIterMut<'a, T> } impl<'a, T> Iterator for CsrRowIterMut<'a, T> @@ -598,27 +525,8 @@ where type Item = CsrRowMut<'a, T>; fn next(&mut self) -> Option { - let lane = self.pattern.get_lane(self.current_row_idx); - let ncols = self.pattern.minor_dim(); - - if let Some(col_indices) = lane { - let count = col_indices.len(); - - // Note: I can't think of any way to construct this iterator without unsafe. - let values_in_row; - unsafe { - values_in_row = &mut *slice_from_raw_parts_mut(self.remaining_values, count); - self.remaining_values = self.remaining_values.add(count); - } - self.current_row_idx += 1; - - Some(CsrRowMut { - ncols, - col_indices, - values: values_in_row - }) - } else { - None - } + self.lane_iter + .next() + .map(|lane| CsrRowMut { lane }) } } \ No newline at end of file diff --git a/nalgebra-sparse/src/lib.rs b/nalgebra-sparse/src/lib.rs index 536f4b75..1812a60a 100644 --- a/nalgebra-sparse/src/lib.rs +++ b/nalgebra-sparse/src/lib.rs @@ -90,6 +90,8 @@ pub mod pattern; pub mod ops; pub mod convert; +mod cs; + #[cfg(feature = "proptest-support")] pub mod proptest;