Verify data validity in try_* constructors
We can easily create the CSR and CSC constructors by using the SparsityPattern constructors. However, one lingering problem is giving meaningful error messages. When forwarding error messages from the SparsityPattern constructor, the error messages must be written in terms of "major" or "minor" dimensions, and using general terms like "lanes", instead of "rows" and "columns". When forwarding these messages up to CSR or CSC constructors, they are not directly meaningful to an end user. We should find a better solution to this problem, so that end users get more meaningful messages.
This commit is contained in:
parent
7f5b702a49
commit
b1199da206
|
@ -82,9 +82,8 @@ impl<T> CsrMatrix<T> {
|
|||
/// bounds with respect to the number of columns in the matrix. If this is not the case,
|
||||
/// an error is returned to indicate the failure.
|
||||
///
|
||||
/// Panics
|
||||
/// ------
|
||||
/// Panics if the lengths of the provided arrays are not compatible with the CSR format.
|
||||
/// An error is returned if the data given does not conform to the CSR storage format.
|
||||
/// See the documentation for [CsrMatrix](struct.CsrMatrix.html) for more information.
|
||||
pub fn try_from_csr_data(
|
||||
num_rows: usize,
|
||||
num_cols: usize,
|
||||
|
@ -92,16 +91,29 @@ impl<T> CsrMatrix<T> {
|
|||
col_indices: Vec<usize>,
|
||||
values: Vec<T>,
|
||||
) -> Result<Self, SparseFormatError> {
|
||||
assert_eq!(col_indices.len(), values.len(),
|
||||
"Number of values and column indices must be the same");
|
||||
let pattern = SparsityPattern::try_from_offsets_and_indices(
|
||||
num_rows, num_cols, row_offsets, col_indices)?;
|
||||
Ok(Self {
|
||||
sparsity_pattern: Arc::new(pattern),
|
||||
values,
|
||||
})
|
||||
Self::try_from_pattern_and_values(Arc::new(pattern), values)
|
||||
}
|
||||
|
||||
/// Try to construct a CSR matrix from a sparsity pattern and associated non-zero values.
|
||||
///
|
||||
/// Returns an error if the number of values does not match the number of minor indices
|
||||
/// in the pattern.
|
||||
pub fn try_from_pattern_and_values(pattern: Arc<SparsityPattern>, values: Vec<T>)
|
||||
-> Result<Self, SparseFormatError> {
|
||||
if pattern.nnz() == values.len() {
|
||||
Ok(Self {
|
||||
sparsity_pattern: pattern,
|
||||
values,
|
||||
})
|
||||
} else {
|
||||
return Err(SparseFormatError::InvalidStructure(
|
||||
Box::from("Number of values and column indices must be the same")));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// An iterator over non-zero triplets (i, j, v).
|
||||
///
|
||||
/// The iteration happens in row-major fashion, meaning that i increases monotonically,
|
||||
|
|
|
@ -1,6 +1,17 @@
|
|||
use crate::SparseFormatError;
|
||||
|
||||
/// A representation of the sparsity pattern of a CSR or COO matrix.
|
||||
/// A representation of the sparsity pattern of a CSR or CSC matrix.
|
||||
///
|
||||
/// ## Format specification
|
||||
///
|
||||
/// TODO: Write this out properly
|
||||
///
|
||||
/// - offsets[0] == 0
|
||||
/// - Major offsets must be monotonically increasing
|
||||
/// - major_offsets.len() == major_dim + 1
|
||||
/// - Column indices within each lane must be sorted
|
||||
/// - Column indices must be in-bounds
|
||||
/// - The last entry in major offsets must correspond to the number of minor indices
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
// TODO: Make SparsityPattern parametrized by index type
|
||||
// (need a solid abstraction for index types though)
|
||||
|
@ -63,16 +74,80 @@ impl SparsityPattern {
|
|||
/// and minor indices.
|
||||
///
|
||||
/// Returns an error if the data does not conform to the requirements.
|
||||
///
|
||||
/// TODO: Maybe we should not do any assertions in any of the construction functions
|
||||
pub fn try_from_offsets_and_indices(
|
||||
major_dim: usize,
|
||||
minor_dim: usize,
|
||||
major_offsets: Vec<usize>,
|
||||
minor_indices: Vec<usize>,
|
||||
) -> Result<Self, SparseFormatError> {
|
||||
assert_eq!(major_offsets.len(), major_dim + 1);
|
||||
assert_eq!(*major_offsets.last().unwrap(), minor_indices.len());
|
||||
// TODO: If these errors are *directly* propagated to errors from e.g.
|
||||
// CSR construction, the error messages will be confusing to users,
|
||||
// as the error messages refer to "major" and "minor" lanes, as opposed to
|
||||
// rows and columns
|
||||
|
||||
if major_offsets.len() != major_dim + 1 {
|
||||
return Err(SparseFormatError::InvalidStructure(
|
||||
Box::from("Size of major_offsets must be equal to (major_dim + 1)")));
|
||||
}
|
||||
|
||||
// Check that the first and last offsets conform to the specification
|
||||
{
|
||||
if *major_offsets.first().unwrap() != 0 {
|
||||
return Err(SparseFormatError::InvalidStructure(
|
||||
Box::from("First entry in major_offsets must always be 0.")
|
||||
));
|
||||
} else if *major_offsets.last().unwrap() != minor_indices.len() {
|
||||
return Err(SparseFormatError::InvalidStructure(
|
||||
Box::from("Last entry in major_offsets must always be equal to minor_indices.len()")
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
// Test that each lane has strictly monotonically increasing minor indices, i.e.
|
||||
// minor indices within a lane are sorted, unique. In addition, each minor index
|
||||
// must be in bounds with respect to the minor dimension.
|
||||
{
|
||||
for lane_idx in 0 .. major_dim {
|
||||
let range_start = major_offsets[lane_idx];
|
||||
let range_end = major_offsets[lane_idx + 1];
|
||||
|
||||
// Test that major offsets are monotonically increasing
|
||||
if range_start > range_end {
|
||||
return Err(SparseFormatError::InvalidStructure(
|
||||
Box::from("Major offsets are not monotonically increasing.")
|
||||
));
|
||||
}
|
||||
|
||||
let minor_indices = &minor_indices[range_start .. range_end];
|
||||
|
||||
// We test for in-bounds, uniqueness and monotonicity at the same time
|
||||
// to ensure that we only visit each minor index once
|
||||
let mut iter = minor_indices.iter();
|
||||
let mut prev = None;
|
||||
|
||||
while let Some(next) = iter.next().copied() {
|
||||
if next > minor_dim {
|
||||
return Err(SparseFormatError::IndexOutOfBounds(
|
||||
Box::from("Minor index out of bounds.")
|
||||
));
|
||||
}
|
||||
|
||||
if let Some(prev) = prev {
|
||||
if prev > next {
|
||||
return Err(SparseFormatError::InvalidStructure(
|
||||
Box::from("Minor indices within a lane must be monotonically increasing (sorted).")
|
||||
));
|
||||
} else if prev == next {
|
||||
return Err(SparseFormatError::DuplicateEntry(
|
||||
Box::from("Duplicate minor entries detected.")
|
||||
));
|
||||
}
|
||||
}
|
||||
prev = Some(next);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
major_offsets,
|
||||
minor_indices,
|
||||
|
@ -83,7 +158,7 @@ impl SparsityPattern {
|
|||
/// An iterator over the explicitly stored "non-zero" entries (i, j).
|
||||
///
|
||||
/// The iteration happens in a lane-major fashion, meaning that the lane index i
|
||||
/// increases monotonically. and the minor index j increases monotonically within each
|
||||
/// increases monotonically, and the minor index j increases monotonically within each
|
||||
/// lane i.
|
||||
///
|
||||
/// Examples
|
||||
|
|
Loading…
Reference in New Issue