Move COO, CSC, CSR constructor at the top of the impls.

This commit is contained in:
Crozet Sébastien 2021-02-25 11:11:29 +01:00
parent 98ae4f3818
commit c6f7cae326
3 changed files with 129 additions and 122 deletions

View File

@ -61,6 +61,15 @@ impl<T> CooMatrix<T> {
} }
} }
/// Construct a zero COO matrix of the given dimensions.
///
/// Specifically, the collection of triplets - corresponding to explicitly stored entries -
/// is empty, so that the matrix (implicitly) represented by the COO matrix consists of all
/// zero entries.
pub fn zeros(nrows: usize, ncols: usize) -> Self {
Self::new(nrows, ncols)
}
/// Try to construct a COO matrix from the given dimensions and a collection of /// Try to construct a COO matrix from the given dimensions and a collection of
/// (i, j, v) triplets. /// (i, j, v) triplets.
/// ///

View File

@ -127,6 +127,17 @@ pub struct CscMatrix<T> {
} }
impl<T> CscMatrix<T> { impl<T> CscMatrix<T> {
/// Constructs a CSC representation of the (square) `n x n` identity matrix.
#[inline]
pub fn identity(n: usize) -> Self
where
T: Scalar + One,
{
Self {
cs: CsMatrix::identity(n),
}
}
/// Create a zero CSC matrix with no explicitly stored entries. /// Create a zero CSC matrix with no explicitly stored entries.
pub fn zeros(nrows: usize, ncols: usize) -> Self { pub fn zeros(nrows: usize, ncols: usize) -> Self {
Self { Self {
@ -134,6 +145,51 @@ impl<T> CscMatrix<T> {
} }
} }
/// Try to construct a CSC matrix from raw CSC data.
///
/// It is assumed that each column contains unique and sorted row indices that are in
/// bounds with respect to the number of rows in the matrix. If this is not the case,
/// an error is returned to indicate the failure.
///
/// An error is returned if the data given does not conform to the CSC storage format.
/// See the documentation for [CscMatrix](struct.CscMatrix.html) for more information.
pub fn try_from_csc_data(
num_rows: usize,
num_cols: usize,
col_offsets: Vec<usize>,
row_indices: Vec<usize>,
values: Vec<T>,
) -> Result<Self, SparseFormatError> {
let pattern = SparsityPattern::try_from_offsets_and_indices(
num_cols,
num_rows,
col_offsets,
row_indices,
)
.map_err(pattern_format_error_to_csc_error)?;
Self::try_from_pattern_and_values(pattern, values)
}
/// Try to construct a CSC matrix from a sparsity pattern and associated non-zero values.
///
/// Returns an error if the number of values does not match the number of minor indices
/// in the pattern.
pub fn try_from_pattern_and_values(
pattern: SparsityPattern,
values: Vec<T>,
) -> Result<Self, SparseFormatError> {
if pattern.nnz() == values.len() {
Ok(Self {
cs: CsMatrix::from_pattern_and_values(pattern, values),
})
} else {
Err(SparseFormatError::from_kind_and_msg(
SparseFormatErrorKind::InvalidStructure,
"Number of values and row indices must be the same",
))
}
}
/// The number of rows in the matrix. /// The number of rows in the matrix.
#[inline] #[inline]
pub fn nrows(&self) -> usize { pub fn nrows(&self) -> usize {
@ -180,51 +236,6 @@ impl<T> CscMatrix<T> {
self.cs.values_mut() self.cs.values_mut()
} }
/// Try to construct a CSC matrix from raw CSC data.
///
/// It is assumed that each column contains unique and sorted row indices that are in
/// bounds with respect to the number of rows in the matrix. If this is not the case,
/// an error is returned to indicate the failure.
///
/// An error is returned if the data given does not conform to the CSC storage format.
/// See the documentation for [CscMatrix](struct.CscMatrix.html) for more information.
pub fn try_from_csc_data(
num_rows: usize,
num_cols: usize,
col_offsets: Vec<usize>,
row_indices: Vec<usize>,
values: Vec<T>,
) -> Result<Self, SparseFormatError> {
let pattern = SparsityPattern::try_from_offsets_and_indices(
num_cols,
num_rows,
col_offsets,
row_indices,
)
.map_err(pattern_format_error_to_csc_error)?;
Self::try_from_pattern_and_values(pattern, values)
}
/// Try to construct a CSC matrix from a sparsity pattern and associated non-zero values.
///
/// Returns an error if the number of values does not match the number of minor indices
/// in the pattern.
pub fn try_from_pattern_and_values(
pattern: SparsityPattern,
values: Vec<T>,
) -> Result<Self, SparseFormatError> {
if pattern.nnz() == values.len() {
Ok(Self {
cs: CsMatrix::from_pattern_and_values(pattern, values),
})
} else {
Err(SparseFormatError::from_kind_and_msg(
SparseFormatErrorKind::InvalidStructure,
"Number of values and row indices must be the same",
))
}
}
/// An iterator over non-zero triplets (i, j, v). /// An iterator over non-zero triplets (i, j, v).
/// ///
/// The iteration happens in column-major fashion, meaning that j increases monotonically, /// The iteration happens in column-major fashion, meaning that j increases monotonically,
@ -485,28 +496,16 @@ impl<T> CscMatrix<T> {
cs: self.cs.diagonal_as_matrix(), cs: self.cs.diagonal_as_matrix(),
} }
} }
}
impl<T> CscMatrix<T>
where
T: Scalar,
{
/// Compute the transpose of the matrix. /// Compute the transpose of the matrix.
pub fn transpose(&self) -> CscMatrix<T> { pub fn transpose(&self) -> CscMatrix<T>
where
T: Scalar,
{
CsrMatrix::from(self).transpose_as_csc() CsrMatrix::from(self).transpose_as_csc()
} }
} }
impl<T: Scalar + One> CscMatrix<T> {
/// Constructs a CSC representation of the (square) `n x n` identity matrix.
#[inline]
pub fn identity(n: usize) -> Self {
Self {
cs: CsMatrix::identity(n),
}
}
}
/// Convert pattern format errors into more meaningful CSC-specific errors. /// Convert pattern format errors into more meaningful CSC-specific errors.
/// ///
/// This ensures that the terminology is consistent: we are talking about rows and columns, /// This ensures that the terminology is consistent: we are talking about rows and columns,

View File

@ -127,6 +127,17 @@ pub struct CsrMatrix<T> {
} }
impl<T> CsrMatrix<T> { impl<T> CsrMatrix<T> {
/// Constructs a CSR representation of the (square) `n x n` identity matrix.
#[inline]
pub fn identity(n: usize) -> Self
where
T: Scalar + One,
{
Self {
cs: CsMatrix::identity(n),
}
}
/// Create a zero CSR matrix with no explicitly stored entries. /// Create a zero CSR matrix with no explicitly stored entries.
pub fn zeros(nrows: usize, ncols: usize) -> Self { pub fn zeros(nrows: usize, ncols: usize) -> Self {
Self { Self {
@ -134,6 +145,51 @@ impl<T> CsrMatrix<T> {
} }
} }
/// Try to construct a CSR matrix from raw CSR data.
///
/// It is assumed that each row contains unique and sorted column indices that are in
/// bounds with respect to the number of columns in the matrix. If this is not the case,
/// an error is returned to indicate the failure.
///
/// An error is returned if the data given does not conform to the CSR storage format.
/// See the documentation for [CsrMatrix](struct.CsrMatrix.html) for more information.
pub fn try_from_csr_data(
num_rows: usize,
num_cols: usize,
row_offsets: Vec<usize>,
col_indices: Vec<usize>,
values: Vec<T>,
) -> Result<Self, SparseFormatError> {
let pattern = SparsityPattern::try_from_offsets_and_indices(
num_rows,
num_cols,
row_offsets,
col_indices,
)
.map_err(pattern_format_error_to_csr_error)?;
Self::try_from_pattern_and_values(pattern, values)
}
/// Try to construct a CSR matrix from a sparsity pattern and associated non-zero values.
///
/// Returns an error if the number of values does not match the number of minor indices
/// in the pattern.
pub fn try_from_pattern_and_values(
pattern: SparsityPattern,
values: Vec<T>,
) -> Result<Self, SparseFormatError> {
if pattern.nnz() == values.len() {
Ok(Self {
cs: CsMatrix::from_pattern_and_values(pattern, values),
})
} else {
Err(SparseFormatError::from_kind_and_msg(
SparseFormatErrorKind::InvalidStructure,
"Number of values and column indices must be the same",
))
}
}
/// The number of rows in the matrix. /// The number of rows in the matrix.
#[inline] #[inline]
pub fn nrows(&self) -> usize { pub fn nrows(&self) -> usize {
@ -182,51 +238,6 @@ impl<T> CsrMatrix<T> {
self.cs.values_mut() self.cs.values_mut()
} }
/// Try to construct a CSR matrix from raw CSR data.
///
/// It is assumed that each row contains unique and sorted column indices that are in
/// bounds with respect to the number of columns in the matrix. If this is not the case,
/// an error is returned to indicate the failure.
///
/// An error is returned if the data given does not conform to the CSR storage format.
/// See the documentation for [CsrMatrix](struct.CsrMatrix.html) for more information.
pub fn try_from_csr_data(
num_rows: usize,
num_cols: usize,
row_offsets: Vec<usize>,
col_indices: Vec<usize>,
values: Vec<T>,
) -> Result<Self, SparseFormatError> {
let pattern = SparsityPattern::try_from_offsets_and_indices(
num_rows,
num_cols,
row_offsets,
col_indices,
)
.map_err(pattern_format_error_to_csr_error)?;
Self::try_from_pattern_and_values(pattern, values)
}
/// Try to construct a CSR matrix from a sparsity pattern and associated non-zero values.
///
/// Returns an error if the number of values does not match the number of minor indices
/// in the pattern.
pub fn try_from_pattern_and_values(
pattern: SparsityPattern,
values: Vec<T>,
) -> Result<Self, SparseFormatError> {
if pattern.nnz() == values.len() {
Ok(Self {
cs: CsMatrix::from_pattern_and_values(pattern, values),
})
} else {
Err(SparseFormatError::from_kind_and_msg(
SparseFormatErrorKind::InvalidStructure,
"Number of values and column indices must be the same",
))
}
}
/// An iterator over non-zero triplets (i, j, v). /// An iterator over non-zero triplets (i, j, v).
/// ///
/// The iteration happens in row-major fashion, meaning that i increases monotonically, /// The iteration happens in row-major fashion, meaning that i increases monotonically,
@ -485,28 +496,16 @@ impl<T> CsrMatrix<T> {
cs: self.cs.diagonal_as_matrix(), cs: self.cs.diagonal_as_matrix(),
} }
} }
}
impl<T> CsrMatrix<T>
where
T: Scalar,
{
/// Compute the transpose of the matrix. /// Compute the transpose of the matrix.
pub fn transpose(&self) -> CsrMatrix<T> { pub fn transpose(&self) -> CsrMatrix<T>
where
T: Scalar,
{
CscMatrix::from(self).transpose_as_csr() CscMatrix::from(self).transpose_as_csr()
} }
} }
impl<T: Scalar + One> CsrMatrix<T> {
/// Constructs a CSR representation of the (square) `n x n` identity matrix.
#[inline]
pub fn identity(n: usize) -> Self {
Self {
cs: CsMatrix::identity(n),
}
}
}
/// Convert pattern format errors into more meaningful CSR-specific errors. /// Convert pattern format errors into more meaningful CSR-specific errors.
/// ///
/// This ensures that the terminology is consistent: we are talking about rows and columns, /// This ensures that the terminology is consistent: we are talking about rows and columns,