From e6e7efba8a924ca937e689f59d6984207f5b93c1 Mon Sep 17 00:00:00 2001 From: Paul Jakob Schroeder Date: Mon, 7 Jun 2021 10:13:43 -0400 Subject: [PATCH] COO: add push_matrix fn - This function allows one to add entire dense matrices to a sparse COO matrix. - Added a small mention of this new function in the example in lib.rs --- nalgebra-sparse/src/coo.rs | 37 ++++++++++ nalgebra-sparse/src/lib.rs | 3 + nalgebra-sparse/tests/unit_tests/coo.rs | 92 +++++++++++++++++++++++++ 3 files changed, 132 insertions(+) diff --git a/nalgebra-sparse/src/coo.rs b/nalgebra-sparse/src/coo.rs index caf74654..3327ed27 100644 --- a/nalgebra-sparse/src/coo.rs +++ b/nalgebra-sparse/src/coo.rs @@ -45,6 +45,43 @@ pub struct CooMatrix { values: Vec, } +impl CooMatrix { + /// Pushes a dense matrix into the sparse one. + /// + /// This adds the dense matrix `m` starting at the `r`th row and `c`th column + /// to the matrix. + /// + /// Panics + /// ------ + /// + /// Panics if any part of the dense matrix is out of bounds of the sparse matrix + /// when inserted at `(r, c)`. + #[inline] + pub fn push_matrix>( + &mut self, + r: usize, + c: usize, + m: &na::Matrix, + ) { + let block_nrows = m.nrows(); + let block_ncols = m.ncols(); + let max_row_with_block = r + block_nrows - 1; + let max_col_with_block = c + block_ncols - 1; + assert!(max_row_with_block < self.nrows); + assert!(max_col_with_block < self.ncols); + + self.reserve(block_ncols * block_nrows); + + for (col_idx, col) in m.column_iter().enumerate() { + for (row_idx, v) in col.iter().enumerate() { + self.row_indices.push(r + row_idx); + self.col_indices.push(c + col_idx); + self.values.push(v.clone()); + } + } + } +} + impl CooMatrix { /// Construct a zero COO matrix of the given dimensions. /// diff --git a/nalgebra-sparse/src/lib.rs b/nalgebra-sparse/src/lib.rs index 6f28917d..80351c07 100644 --- a/nalgebra-sparse/src/lib.rs +++ b/nalgebra-sparse/src/lib.rs @@ -93,6 +93,9 @@ //! coo.push(1, 2, 1.3); //! coo.push(2, 2, 4.1); //! +//! // ... or add entire dense matrices like so: +//! // coo.push_matrix(0, 0, &dense); +//! //! // The simplest way to construct a CSR matrix is to first construct a COO matrix, and //! // then convert it to CSR. The `From` trait is implemented for conversions between different //! // sparse matrix types. diff --git a/nalgebra-sparse/tests/unit_tests/coo.rs b/nalgebra-sparse/tests/unit_tests/coo.rs index c9fa1778..c70c5f97 100644 --- a/nalgebra-sparse/tests/unit_tests/coo.rs +++ b/nalgebra-sparse/tests/unit_tests/coo.rs @@ -252,3 +252,95 @@ fn coo_push_out_of_bounds_entries() { assert_panics!(coo.clone().push(3, 2, 1)); } } + +#[test] +fn coo_push_matrix_valid_entries() { + let mut coo = CooMatrix::new(3, 3); + + // Works with static + { + // new is row-major... + let inserted = nalgebra::SMatrix::::new(1, 2, 3, 4); + coo.push_matrix(1, 1, &inserted); + + // insert happens column-major, so expect transposition when read this way + assert_eq!( + coo.triplet_iter().collect::>(), + vec![(1, 1, &1), (2, 1, &3), (1, 2, &2), (2, 2, &4)] + ); + } + + // Works with owned dynamic + { + let inserted = nalgebra::DMatrix::::repeat(1, 2, 5); + coo.push_matrix(0, 0, &inserted); + + assert_eq!( + coo.triplet_iter().collect::>(), + vec![ + (1, 1, &1), + (2, 1, &3), + (1, 2, &2), + (2, 2, &4), + (0, 0, &5), + (0, 1, &5) + ] + ); + } + + // Works with sliced + { + let source = nalgebra::SMatrix::::new(6, 7, 8, 9); + let sliced = source.fixed_slice::<2, 1>(0, 0); + coo.push_matrix(1, 0, &sliced); + + assert_eq!( + coo.triplet_iter().collect::>(), + vec![ + (1, 1, &1), + (2, 1, &3), + (1, 2, &2), + (2, 2, &4), + (0, 0, &5), + (0, 1, &5), + (1, 0, &6), + (2, 0, &8) + ] + ); + } +} + +#[test] +fn coo_push_matrix_out_of_bounds_entries() { + // 0x0 + { + let inserted = nalgebra::SMatrix::::new(1); + assert_panics!(CooMatrix::new(0, 0).push_matrix(0, 0, &inserted)); + } + // 0x1 + { + let inserted = nalgebra::SMatrix::::new(1); + assert_panics!(CooMatrix::new(1, 0).push_matrix(0, 0, &inserted)); + } + // 1x0 + { + let inserted = nalgebra::SMatrix::::new(1); + assert_panics!(CooMatrix::new(0, 1).push_matrix(0, 0, &inserted)); + } + + // 3x3 exceeds col-dim + { + let inserted = nalgebra::SMatrix::::repeat(1); + assert_panics!(CooMatrix::new(3, 3).push_matrix(0, 2, &inserted)); + } + // 3x3 exceeds row-dim + { + let inserted = nalgebra::SMatrix::::repeat(1); + assert_panics!(CooMatrix::new(3, 3).push_matrix(2, 0, &inserted)); + } + // 3x3 exceeds row-dim and row-dim + { + let inserted = nalgebra::SMatrix::::repeat(1); + assert_panics!(CooMatrix::new(3, 3).push_matrix(2, 2, &inserted)); + } +}