diff --git a/nalgebra-sparse/src/cs.rs b/nalgebra-sparse/src/cs.rs index edcb7fa3..c1b4b449 100644 --- a/nalgebra-sparse/src/cs.rs +++ b/nalgebra-sparse/src/cs.rs @@ -4,6 +4,8 @@ use crate::{SparseEntry, SparseEntryMut}; use std::sync::Arc; use std::ops::Range; use std::mem::replace; +use num_traits::One; +use nalgebra::Scalar; /// An abstract compressed matrix. /// @@ -156,6 +158,21 @@ impl CsMatrix { } } +impl CsMatrix { + /// TODO + #[inline] + pub fn identity(n: usize) -> Self { + let offsets: Vec<_> = (0 ..= n).collect(); + let indices: Vec<_> = (0 .. n).collect(); + let values = vec![T::one(); n]; + + // TODO: We should skip checks here + let pattern = SparsityPattern::try_from_offsets_and_indices(n, n, offsets, indices) + .unwrap(); + Self::from_pattern_and_values(Arc::new(pattern), values) + } +} + fn get_entry_from_slices<'a, T>( minor_dim: usize, minor_indices: &'a [usize], diff --git a/nalgebra-sparse/src/csc.rs b/nalgebra-sparse/src/csc.rs index 7b3b8c10..f77bfaef 100644 --- a/nalgebra-sparse/src/csc.rs +++ b/nalgebra-sparse/src/csc.rs @@ -7,7 +7,7 @@ use crate::cs::{CsMatrix, CsLane, CsLaneMut, CsLaneIter, CsLaneIterMut}; use std::sync::Arc; use std::slice::{IterMut, Iter}; -use num_traits::Zero; +use num_traits::{Zero, One}; use nalgebra::Scalar; /// A CSC representation of a sparse matrix. @@ -337,6 +337,16 @@ impl CscMatrix } } +impl CscMatrix { + /// TODO + #[inline] + pub fn identity(n: usize) -> Self { + Self { + cs: CsMatrix::identity(n) + } + } +} + /// Convert pattern format errors into more meaningful CSC-specific errors. /// /// This ensures that the terminology is consistent: we are talking about rows and columns, diff --git a/nalgebra-sparse/src/csr.rs b/nalgebra-sparse/src/csr.rs index d5b8e92e..ca42492c 100644 --- a/nalgebra-sparse/src/csr.rs +++ b/nalgebra-sparse/src/csr.rs @@ -5,7 +5,7 @@ use crate::csc::CscMatrix; use crate::cs::{CsMatrix, CsLaneIterMut, CsLaneIter, CsLane, CsLaneMut}; use nalgebra::Scalar; -use num_traits::Zero; +use num_traits::{Zero, One}; use std::sync::Arc; use std::slice::{IterMut, Iter}; @@ -338,6 +338,16 @@ where } } +impl CsrMatrix { + /// TODO + #[inline] + pub fn identity(n: usize) -> Self { + Self { + cs: CsMatrix::identity(n) + } + } +} + /// Convert pattern format errors into more meaningful CSR-specific errors. /// /// This ensures that the terminology is consistent: we are talking about rows and columns,