Merge pull request #904 from CattleProdigy/coo-push-mat

Add push_matrix fcn to COO
This commit is contained in:
Sébastien Crozet 2021-06-14 14:39:16 +02:00 committed by GitHub
commit 2287e5088a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 132 additions and 0 deletions

View File

@ -45,6 +45,43 @@ pub struct CooMatrix<T> {
values: Vec<T>, values: Vec<T>,
} }
impl<T: na::Scalar> CooMatrix<T> {
/// Pushes a dense matrix into the sparse one.
///
/// This adds the dense matrix `m` starting at the `r`th row and `c`th column
/// to the matrix.
///
/// Panics
/// ------
///
/// Panics if any part of the dense matrix is out of bounds of the sparse matrix
/// when inserted at `(r, c)`.
#[inline]
pub fn push_matrix<R: na::Dim, C: na::Dim, S: nalgebra::storage::Storage<T, R, C>>(
&mut self,
r: usize,
c: usize,
m: &na::Matrix<T, R, C, S>,
) {
let block_nrows = m.nrows();
let block_ncols = m.ncols();
let max_row_with_block = r + block_nrows - 1;
let max_col_with_block = c + block_ncols - 1;
assert!(max_row_with_block < self.nrows);
assert!(max_col_with_block < self.ncols);
self.reserve(block_ncols * block_nrows);
for (col_idx, col) in m.column_iter().enumerate() {
for (row_idx, v) in col.iter().enumerate() {
self.row_indices.push(r + row_idx);
self.col_indices.push(c + col_idx);
self.values.push(v.clone());
}
}
}
}
impl<T> CooMatrix<T> { impl<T> CooMatrix<T> {
/// Construct a zero COO matrix of the given dimensions. /// Construct a zero COO matrix of the given dimensions.
/// ///

View File

@ -93,6 +93,9 @@
//! coo.push(1, 2, 1.3); //! coo.push(1, 2, 1.3);
//! coo.push(2, 2, 4.1); //! coo.push(2, 2, 4.1);
//! //!
//! // ... or add entire dense matrices like so:
//! // coo.push_matrix(0, 0, &dense);
//!
//! // The simplest way to construct a CSR matrix is to first construct a COO matrix, and //! // The simplest way to construct a CSR matrix is to first construct a COO matrix, and
//! // then convert it to CSR. The `From` trait is implemented for conversions between different //! // then convert it to CSR. The `From` trait is implemented for conversions between different
//! // sparse matrix types. //! // sparse matrix types.

View File

@ -252,3 +252,95 @@ fn coo_push_out_of_bounds_entries() {
assert_panics!(coo.clone().push(3, 2, 1)); assert_panics!(coo.clone().push(3, 2, 1));
} }
} }
#[test]
fn coo_push_matrix_valid_entries() {
let mut coo = CooMatrix::new(3, 3);
// Works with static
{
// new is row-major...
let inserted = nalgebra::SMatrix::<i32, 2, 2>::new(1, 2, 3, 4);
coo.push_matrix(1, 1, &inserted);
// insert happens column-major, so expect transposition when read this way
assert_eq!(
coo.triplet_iter().collect::<Vec<_>>(),
vec![(1, 1, &1), (2, 1, &3), (1, 2, &2), (2, 2, &4)]
);
}
// Works with owned dynamic
{
let inserted = nalgebra::DMatrix::<i32>::repeat(1, 2, 5);
coo.push_matrix(0, 0, &inserted);
assert_eq!(
coo.triplet_iter().collect::<Vec<_>>(),
vec![
(1, 1, &1),
(2, 1, &3),
(1, 2, &2),
(2, 2, &4),
(0, 0, &5),
(0, 1, &5)
]
);
}
// Works with sliced
{
let source = nalgebra::SMatrix::<i32, 2, 2>::new(6, 7, 8, 9);
let sliced = source.fixed_slice::<2, 1>(0, 0);
coo.push_matrix(1, 0, &sliced);
assert_eq!(
coo.triplet_iter().collect::<Vec<_>>(),
vec![
(1, 1, &1),
(2, 1, &3),
(1, 2, &2),
(2, 2, &4),
(0, 0, &5),
(0, 1, &5),
(1, 0, &6),
(2, 0, &8)
]
);
}
}
#[test]
fn coo_push_matrix_out_of_bounds_entries() {
// 0x0
{
let inserted = nalgebra::SMatrix::<i32, 1, 1>::new(1);
assert_panics!(CooMatrix::new(0, 0).push_matrix(0, 0, &inserted));
}
// 0x1
{
let inserted = nalgebra::SMatrix::<i32, 1, 1>::new(1);
assert_panics!(CooMatrix::new(1, 0).push_matrix(0, 0, &inserted));
}
// 1x0
{
let inserted = nalgebra::SMatrix::<i32, 1, 1>::new(1);
assert_panics!(CooMatrix::new(0, 1).push_matrix(0, 0, &inserted));
}
// 3x3 exceeds col-dim
{
let inserted = nalgebra::SMatrix::<i32, 1, 2>::repeat(1);
assert_panics!(CooMatrix::new(3, 3).push_matrix(0, 2, &inserted));
}
// 3x3 exceeds row-dim
{
let inserted = nalgebra::SMatrix::<i32, 2, 1>::repeat(1);
assert_panics!(CooMatrix::new(3, 3).push_matrix(2, 0, &inserted));
}
// 3x3 exceeds row-dim and row-dim
{
let inserted = nalgebra::SMatrix::<i32, 2, 2>::repeat(1);
assert_panics!(CooMatrix::new(3, 3).push_matrix(2, 2, &inserted));
}
}