COO: add push_matrix fn
- This function allows one to add entire dense matrices to a sparse COO matrix. - Added a small mention of this new function in the example in lib.rs
This commit is contained in:
parent
543f964610
commit
e6e7efba8a
|
@ -45,6 +45,43 @@ pub struct CooMatrix<T> {
|
|||
values: Vec<T>,
|
||||
}
|
||||
|
||||
impl<T: na::Scalar> CooMatrix<T> {
|
||||
/// Pushes a dense matrix into the sparse one.
|
||||
///
|
||||
/// This adds the dense matrix `m` starting at the `r`th row and `c`th column
|
||||
/// to the matrix.
|
||||
///
|
||||
/// Panics
|
||||
/// ------
|
||||
///
|
||||
/// Panics if any part of the dense matrix is out of bounds of the sparse matrix
|
||||
/// when inserted at `(r, c)`.
|
||||
#[inline]
|
||||
pub fn push_matrix<R: na::Dim, C: na::Dim, S: nalgebra::storage::Storage<T, R, C>>(
|
||||
&mut self,
|
||||
r: usize,
|
||||
c: usize,
|
||||
m: &na::Matrix<T, R, C, S>,
|
||||
) {
|
||||
let block_nrows = m.nrows();
|
||||
let block_ncols = m.ncols();
|
||||
let max_row_with_block = r + block_nrows - 1;
|
||||
let max_col_with_block = c + block_ncols - 1;
|
||||
assert!(max_row_with_block < self.nrows);
|
||||
assert!(max_col_with_block < self.ncols);
|
||||
|
||||
self.reserve(block_ncols * block_nrows);
|
||||
|
||||
for (col_idx, col) in m.column_iter().enumerate() {
|
||||
for (row_idx, v) in col.iter().enumerate() {
|
||||
self.row_indices.push(r + row_idx);
|
||||
self.col_indices.push(c + col_idx);
|
||||
self.values.push(v.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> CooMatrix<T> {
|
||||
/// Construct a zero COO matrix of the given dimensions.
|
||||
///
|
||||
|
|
|
@ -93,6 +93,9 @@
|
|||
//! coo.push(1, 2, 1.3);
|
||||
//! coo.push(2, 2, 4.1);
|
||||
//!
|
||||
//! // ... or add entire dense matrices like so:
|
||||
//! // coo.push_matrix(0, 0, &dense);
|
||||
//!
|
||||
//! // The simplest way to construct a CSR matrix is to first construct a COO matrix, and
|
||||
//! // then convert it to CSR. The `From` trait is implemented for conversions between different
|
||||
//! // sparse matrix types.
|
||||
|
|
|
@ -252,3 +252,95 @@ fn coo_push_out_of_bounds_entries() {
|
|||
assert_panics!(coo.clone().push(3, 2, 1));
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn coo_push_matrix_valid_entries() {
|
||||
let mut coo = CooMatrix::new(3, 3);
|
||||
|
||||
// Works with static
|
||||
{
|
||||
// new is row-major...
|
||||
let inserted = nalgebra::SMatrix::<i32, 2, 2>::new(1, 2, 3, 4);
|
||||
coo.push_matrix(1, 1, &inserted);
|
||||
|
||||
// insert happens column-major, so expect transposition when read this way
|
||||
assert_eq!(
|
||||
coo.triplet_iter().collect::<Vec<_>>(),
|
||||
vec![(1, 1, &1), (2, 1, &3), (1, 2, &2), (2, 2, &4)]
|
||||
);
|
||||
}
|
||||
|
||||
// Works with owned dynamic
|
||||
{
|
||||
let inserted = nalgebra::DMatrix::<i32>::repeat(1, 2, 5);
|
||||
coo.push_matrix(0, 0, &inserted);
|
||||
|
||||
assert_eq!(
|
||||
coo.triplet_iter().collect::<Vec<_>>(),
|
||||
vec![
|
||||
(1, 1, &1),
|
||||
(2, 1, &3),
|
||||
(1, 2, &2),
|
||||
(2, 2, &4),
|
||||
(0, 0, &5),
|
||||
(0, 1, &5)
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
// Works with sliced
|
||||
{
|
||||
let source = nalgebra::SMatrix::<i32, 2, 2>::new(6, 7, 8, 9);
|
||||
let sliced = source.fixed_slice::<2, 1>(0, 0);
|
||||
coo.push_matrix(1, 0, &sliced);
|
||||
|
||||
assert_eq!(
|
||||
coo.triplet_iter().collect::<Vec<_>>(),
|
||||
vec![
|
||||
(1, 1, &1),
|
||||
(2, 1, &3),
|
||||
(1, 2, &2),
|
||||
(2, 2, &4),
|
||||
(0, 0, &5),
|
||||
(0, 1, &5),
|
||||
(1, 0, &6),
|
||||
(2, 0, &8)
|
||||
]
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn coo_push_matrix_out_of_bounds_entries() {
|
||||
// 0x0
|
||||
{
|
||||
let inserted = nalgebra::SMatrix::<i32, 1, 1>::new(1);
|
||||
assert_panics!(CooMatrix::new(0, 0).push_matrix(0, 0, &inserted));
|
||||
}
|
||||
// 0x1
|
||||
{
|
||||
let inserted = nalgebra::SMatrix::<i32, 1, 1>::new(1);
|
||||
assert_panics!(CooMatrix::new(1, 0).push_matrix(0, 0, &inserted));
|
||||
}
|
||||
// 1x0
|
||||
{
|
||||
let inserted = nalgebra::SMatrix::<i32, 1, 1>::new(1);
|
||||
assert_panics!(CooMatrix::new(0, 1).push_matrix(0, 0, &inserted));
|
||||
}
|
||||
|
||||
// 3x3 exceeds col-dim
|
||||
{
|
||||
let inserted = nalgebra::SMatrix::<i32, 1, 2>::repeat(1);
|
||||
assert_panics!(CooMatrix::new(3, 3).push_matrix(0, 2, &inserted));
|
||||
}
|
||||
// 3x3 exceeds row-dim
|
||||
{
|
||||
let inserted = nalgebra::SMatrix::<i32, 2, 1>::repeat(1);
|
||||
assert_panics!(CooMatrix::new(3, 3).push_matrix(2, 0, &inserted));
|
||||
}
|
||||
// 3x3 exceeds row-dim and row-dim
|
||||
{
|
||||
let inserted = nalgebra::SMatrix::<i32, 2, 2>::repeat(1);
|
||||
assert_panics!(CooMatrix::new(3, 3).push_matrix(2, 2, &inserted));
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue