Implement Csr/CscMatrix::identity

This commit is contained in:
Andreas Longva 2021-01-11 15:03:58 +01:00
parent ea6c1451b4
commit 6e34c23d05
3 changed files with 39 additions and 2 deletions

View File

@ -4,6 +4,8 @@ use crate::{SparseEntry, SparseEntryMut};
use std::sync::Arc; use std::sync::Arc;
use std::ops::Range; use std::ops::Range;
use std::mem::replace; use std::mem::replace;
use num_traits::One;
use nalgebra::Scalar;
/// An abstract compressed matrix. /// An abstract compressed matrix.
/// ///
@ -156,6 +158,21 @@ impl<T> CsMatrix<T> {
} }
} }
impl<T: Scalar + One> CsMatrix<T> {
/// TODO
#[inline]
pub fn identity(n: usize) -> Self {
let offsets: Vec<_> = (0 ..= n).collect();
let indices: Vec<_> = (0 .. n).collect();
let values = vec![T::one(); n];
// TODO: We should skip checks here
let pattern = SparsityPattern::try_from_offsets_and_indices(n, n, offsets, indices)
.unwrap();
Self::from_pattern_and_values(Arc::new(pattern), values)
}
}
fn get_entry_from_slices<'a, T>( fn get_entry_from_slices<'a, T>(
minor_dim: usize, minor_dim: usize,
minor_indices: &'a [usize], minor_indices: &'a [usize],

View File

@ -7,7 +7,7 @@ use crate::cs::{CsMatrix, CsLane, CsLaneMut, CsLaneIter, CsLaneIterMut};
use std::sync::Arc; use std::sync::Arc;
use std::slice::{IterMut, Iter}; use std::slice::{IterMut, Iter};
use num_traits::Zero; use num_traits::{Zero, One};
use nalgebra::Scalar; use nalgebra::Scalar;
/// A CSC representation of a sparse matrix. /// A CSC representation of a sparse matrix.
@ -337,6 +337,16 @@ impl<T> CscMatrix<T>
} }
} }
impl<T: Scalar + One> CscMatrix<T> {
/// TODO
#[inline]
pub fn identity(n: usize) -> Self {
Self {
cs: CsMatrix::identity(n)
}
}
}
/// Convert pattern format errors into more meaningful CSC-specific errors. /// Convert pattern format errors into more meaningful CSC-specific errors.
/// ///
/// This ensures that the terminology is consistent: we are talking about rows and columns, /// This ensures that the terminology is consistent: we are talking about rows and columns,

View File

@ -5,7 +5,7 @@ use crate::csc::CscMatrix;
use crate::cs::{CsMatrix, CsLaneIterMut, CsLaneIter, CsLane, CsLaneMut}; use crate::cs::{CsMatrix, CsLaneIterMut, CsLaneIter, CsLane, CsLaneMut};
use nalgebra::Scalar; use nalgebra::Scalar;
use num_traits::Zero; use num_traits::{Zero, One};
use std::sync::Arc; use std::sync::Arc;
use std::slice::{IterMut, Iter}; use std::slice::{IterMut, Iter};
@ -338,6 +338,16 @@ where
} }
} }
impl<T: Scalar + One> CsrMatrix<T> {
/// TODO
#[inline]
pub fn identity(n: usize) -> Self {
Self {
cs: CsMatrix::identity(n)
}
}
}
/// Convert pattern format errors into more meaningful CSR-specific errors. /// Convert pattern format errors into more meaningful CSR-specific errors.
/// ///
/// This ensures that the terminology is consistent: we are talking about rows and columns, /// This ensures that the terminology is consistent: we are talking about rows and columns,