Redesign error handling for CSR and SparsityPattern construction
SparsityPattern's constructor now returns a fine-grained error enum that enumerates possible errors. We use this to build a more user-friendly error when constructing CSR matrices. We also overhauled the main SparseFormatError error type by making it a struct containing a *Kind type and an underlying error that contains the message.
This commit is contained in:
parent
7e94a1539a
commit
7a5f8ef1ea
|
@ -76,13 +76,14 @@ where
|
|||
col_indices: Vec<usize>,
|
||||
values: Vec<T>,
|
||||
) -> Result<Self, SparseFormatError> {
|
||||
use crate::SparseFormatErrorKind::*;
|
||||
if row_indices.len() != col_indices.len() {
|
||||
return Err(SparseFormatError::InvalidStructure(
|
||||
Box::from("Number of row and col indices must be the same.")
|
||||
return Err(SparseFormatError::from_kind_and_msg(
|
||||
InvalidStructure, "Number of row and col indices must be the same."
|
||||
));
|
||||
} else if col_indices.len() != values.len() {
|
||||
return Err(SparseFormatError::InvalidStructure(
|
||||
Box::from("Number of col indices and values must be the same.")
|
||||
return Err(SparseFormatError::from_kind_and_msg(
|
||||
InvalidStructure, "Number of col indices and values must be the same."
|
||||
));
|
||||
}
|
||||
|
||||
|
@ -90,13 +91,9 @@ where
|
|||
let col_indices_in_bounds = col_indices.iter().all(|j| *j < ncols);
|
||||
|
||||
if !row_indices_in_bounds {
|
||||
Err(SparseFormatError::IndexOutOfBounds(Box::from(
|
||||
"Row index out of bounds.",
|
||||
)))
|
||||
Err(SparseFormatError::from_kind_and_msg(IndexOutOfBounds, "Row index out of bounds."))
|
||||
} else if !col_indices_in_bounds {
|
||||
Err(SparseFormatError::IndexOutOfBounds(Box::from(
|
||||
"Col index out of bounds.",
|
||||
)))
|
||||
Err(SparseFormatError::from_kind_and_msg(IndexOutOfBounds, "Col index out of bounds."))
|
||||
} else {
|
||||
Ok(Self {
|
||||
nrows,
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
use crate::{SparsityPattern, SparseFormatError};
|
||||
use crate::{SparsityPattern, SparseFormatError, SparsityPatternFormatError, SparseFormatErrorKind};
|
||||
use crate::iter::SparsityPatternIter;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
@ -92,7 +92,8 @@ impl<T> CsrMatrix<T> {
|
|||
values: Vec<T>,
|
||||
) -> Result<Self, SparseFormatError> {
|
||||
let pattern = SparsityPattern::try_from_offsets_and_indices(
|
||||
num_rows, num_cols, row_offsets, col_indices)?;
|
||||
num_rows, num_cols, row_offsets, col_indices)
|
||||
.map_err(pattern_format_error_to_csr_error)?;
|
||||
Self::try_from_pattern_and_values(Arc::new(pattern), values)
|
||||
}
|
||||
|
||||
|
@ -108,8 +109,9 @@ impl<T> CsrMatrix<T> {
|
|||
values,
|
||||
})
|
||||
} else {
|
||||
return Err(SparseFormatError::InvalidStructure(
|
||||
Box::from("Number of values and column indices must be the same")));
|
||||
Err(SparseFormatError::from_kind_and_msg(
|
||||
SparseFormatErrorKind::InvalidStructure,
|
||||
"Number of values and column indices must be the same"))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -264,6 +266,38 @@ impl<T: Clone + Zero> CsrMatrix<T> {
|
|||
}
|
||||
}
|
||||
|
||||
/// Convert pattern format errors into more meaningful CSR-specific errors.
|
||||
///
|
||||
/// This ensures that the terminology is consistent: we are talking about rows and columns,
|
||||
/// not lanes, major and minor dimensions.
|
||||
fn pattern_format_error_to_csr_error(err: SparsityPatternFormatError) -> SparseFormatError {
|
||||
use SparsityPatternFormatError::*;
|
||||
use SparsityPatternFormatError::DuplicateEntry as PatternDuplicateEntry;
|
||||
use SparseFormatError as E;
|
||||
use SparseFormatErrorKind as K;
|
||||
|
||||
match err {
|
||||
InvalidOffsetArrayLength => E::from_kind_and_msg(
|
||||
K::InvalidStructure,
|
||||
"Length of row offset array is not equal to nrows + 1."),
|
||||
InvalidOffsetFirstLast => E::from_kind_and_msg(
|
||||
K::InvalidStructure,
|
||||
"First or last row offset is inconsistent with format specification."),
|
||||
NonmonotonicOffsets => E::from_kind_and_msg(
|
||||
K::InvalidStructure,
|
||||
"Row offsets are not monotonically increasing."),
|
||||
NonmonotonicMinorIndices => E::from_kind_and_msg(
|
||||
K::InvalidStructure,
|
||||
"Column indices are not monotonically increasing (sorted) within each row."),
|
||||
MinorIndexOutOfBounds => E::from_kind_and_msg(
|
||||
K::IndexOutOfBounds,
|
||||
"Column indices are out of bounds."),
|
||||
PatternDuplicateEntry => E::from_kind_and_msg(
|
||||
K::DuplicateEntry,
|
||||
"Matrix data contains duplicate entries."),
|
||||
}
|
||||
}
|
||||
|
||||
/// Iterator type for iterating over triplets in a CSR matrix.
|
||||
#[derive(Debug)]
|
||||
pub struct CsrTripletIter<'a, T> {
|
||||
|
@ -360,7 +394,7 @@ macro_rules! impl_csr_row_common_methods {
|
|||
/// bounds.
|
||||
///
|
||||
/// If the index is in bounds, but no explicitly stored entry is associated with it,
|
||||
/// `T::zero()` is returned. Note that this methods offers no way of distinguishing
|
||||
/// `T::zero()` is returned. Note that this method offers no way of distinguishing
|
||||
/// explicitly stored zero entries from zero values that are only implicitly represented.
|
||||
///
|
||||
/// Each call to this function incurs the cost of a binary search among the explicitly
|
||||
|
|
|
@ -73,7 +73,7 @@ pub mod ops;
|
|||
|
||||
pub use coo::CooMatrix;
|
||||
pub use csr::{CsrMatrix, CsrRow, CsrRowMut};
|
||||
pub use pattern::{SparsityPattern};
|
||||
pub use pattern::{SparsityPattern, SparsityPatternFormatError};
|
||||
|
||||
/// Iterator types for matrices.
|
||||
///
|
||||
|
@ -94,31 +94,53 @@ use std::fmt;
|
|||
|
||||
/// Errors produced by functions that expect well-formed sparse format data.
|
||||
#[derive(Debug)]
|
||||
#[non_exhaustive]
|
||||
pub enum SparseFormatError {
|
||||
pub struct SparseFormatError {
|
||||
kind: SparseFormatErrorKind,
|
||||
// Currently we only use an underlying error for generating the `Display` impl
|
||||
error: Box<dyn Error>
|
||||
}
|
||||
|
||||
impl SparseFormatError {
|
||||
/// The type of error.
|
||||
pub fn kind(&self) -> &SparseFormatErrorKind {
|
||||
&self.kind
|
||||
}
|
||||
|
||||
pub(crate) fn from_kind_and_error(kind: SparseFormatErrorKind, error: Box<dyn Error>) -> Self {
|
||||
Self {
|
||||
kind,
|
||||
error
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper functionality for more conveniently creating errors.
|
||||
pub(crate) fn from_kind_and_msg(kind: SparseFormatErrorKind, msg: &'static str) -> Self {
|
||||
Self::from_kind_and_error(kind, Box::<dyn Error>::from(msg))
|
||||
}
|
||||
}
|
||||
|
||||
/// The type of format error described by a [SparseFormatError](struct.SparseFormatError.html).
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum SparseFormatErrorKind {
|
||||
/// Indicates that the index data associated with the format contains at least one index
|
||||
/// out of bounds.
|
||||
IndexOutOfBounds(Box<dyn Error>),
|
||||
IndexOutOfBounds,
|
||||
|
||||
/// Indicates that the provided data contains at least one duplicate entry, and the
|
||||
/// current format does not support duplicate entries.
|
||||
DuplicateEntry(Box<dyn Error>),
|
||||
DuplicateEntry,
|
||||
|
||||
/// Indicates that the provided data for the format does not conform to the high-level
|
||||
/// structure of the format.
|
||||
///
|
||||
/// For example, the arrays defining the format data might have incompatible sizes.
|
||||
InvalidStructure(Box<dyn Error>),
|
||||
InvalidStructure,
|
||||
}
|
||||
|
||||
impl fmt::Display for SparseFormatError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
match self {
|
||||
Self::IndexOutOfBounds(err) => err.fmt(f),
|
||||
Self::DuplicateEntry(err) => err.fmt(f),
|
||||
Self::InvalidStructure(err) => err.fmt(f)
|
||||
}
|
||||
write!(f, "{}", self.error)
|
||||
}
|
||||
}
|
||||
|
||||
impl Error for SparseFormatError {}
|
||||
impl Error for SparseFormatError {}
|
|
@ -1,4 +1,6 @@
|
|||
use crate::SparseFormatError;
|
||||
use std::fmt;
|
||||
use std::error::Error;
|
||||
|
||||
/// A representation of the sparsity pattern of a CSR or CSC matrix.
|
||||
///
|
||||
|
@ -79,27 +81,24 @@ impl SparsityPattern {
|
|||
minor_dim: usize,
|
||||
major_offsets: Vec<usize>,
|
||||
minor_indices: Vec<usize>,
|
||||
) -> Result<Self, SparseFormatError> {
|
||||
) -> Result<Self, SparsityPatternFormatError> {
|
||||
// TODO: If these errors are *directly* propagated to errors from e.g.
|
||||
// CSR construction, the error messages will be confusing to users,
|
||||
// as the error messages refer to "major" and "minor" lanes, as opposed to
|
||||
// rows and columns
|
||||
|
||||
use SparsityPatternFormatError::*;
|
||||
|
||||
if major_offsets.len() != major_dim + 1 {
|
||||
return Err(SparseFormatError::InvalidStructure(
|
||||
Box::from("Size of major_offsets must be equal to (major_dim + 1)")));
|
||||
return Err(InvalidOffsetArrayLength);
|
||||
}
|
||||
|
||||
// Check that the first and last offsets conform to the specification
|
||||
{
|
||||
if *major_offsets.first().unwrap() != 0 {
|
||||
return Err(SparseFormatError::InvalidStructure(
|
||||
Box::from("First entry in major_offsets must always be 0.")
|
||||
));
|
||||
} else if *major_offsets.last().unwrap() != minor_indices.len() {
|
||||
return Err(SparseFormatError::InvalidStructure(
|
||||
Box::from("Last entry in major_offsets must always be equal to minor_indices.len()")
|
||||
));
|
||||
let first_offset_ok = *major_offsets.first().unwrap() == 0;
|
||||
let last_offset_ok = *major_offsets.last().unwrap() == minor_indices.len();
|
||||
if !first_offset_ok || !last_offset_ok {
|
||||
return Err(InvalidOffsetFirstLast);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -113,9 +112,7 @@ impl SparsityPattern {
|
|||
|
||||
// Test that major offsets are monotonically increasing
|
||||
if range_start > range_end {
|
||||
return Err(SparseFormatError::InvalidStructure(
|
||||
Box::from("Major offsets are not monotonically increasing.")
|
||||
));
|
||||
return Err(NonmonotonicOffsets);
|
||||
}
|
||||
|
||||
let minor_indices = &minor_indices[range_start .. range_end];
|
||||
|
@ -127,20 +124,14 @@ impl SparsityPattern {
|
|||
|
||||
while let Some(next) = iter.next().copied() {
|
||||
if next > minor_dim {
|
||||
return Err(SparseFormatError::IndexOutOfBounds(
|
||||
Box::from("Minor index out of bounds.")
|
||||
));
|
||||
return Err(MinorIndexOutOfBounds);
|
||||
}
|
||||
|
||||
if let Some(prev) = prev {
|
||||
if prev > next {
|
||||
return Err(SparseFormatError::InvalidStructure(
|
||||
Box::from("Minor indices within a lane must be monotonically increasing (sorted).")
|
||||
));
|
||||
return Err(NonmonotonicMinorIndices);
|
||||
} else if prev == next {
|
||||
return Err(SparseFormatError::DuplicateEntry(
|
||||
Box::from("Duplicate minor entries detected.")
|
||||
));
|
||||
return Err(DuplicateEntry);
|
||||
}
|
||||
}
|
||||
prev = Some(next);
|
||||
|
@ -180,6 +171,82 @@ impl SparsityPattern {
|
|||
}
|
||||
}
|
||||
|
||||
/// Error type for `SparsityPattern` format errors.
|
||||
#[non_exhaustive]
|
||||
#[derive(Debug)]
|
||||
pub enum SparsityPatternFormatError {
|
||||
/// Indicates an invalid number of offsets.
|
||||
///
|
||||
/// The number of offsets must be equal to (major_dim + 1).
|
||||
InvalidOffsetArrayLength,
|
||||
/// Indicates that the first or last entry in the offset array did not conform to
|
||||
/// specifications.
|
||||
///
|
||||
/// The first entry must be 0, and the last entry must be exactly one greater than the
|
||||
/// major dimension.
|
||||
InvalidOffsetFirstLast,
|
||||
/// Indicates that the major offsets are not monotonically increasing.
|
||||
NonmonotonicOffsets,
|
||||
/// One or more minor indices are out of bounds.
|
||||
MinorIndexOutOfBounds,
|
||||
/// One or more duplicate entries were detected.
|
||||
///
|
||||
/// Two entries are considered duplicates if they are part of the same major lane and have
|
||||
/// the same minor index.
|
||||
DuplicateEntry,
|
||||
/// Indicates that minor indices are not monotonically increasing within each lane.
|
||||
NonmonotonicMinorIndices,
|
||||
}
|
||||
|
||||
impl From<SparsityPatternFormatError> for SparseFormatError {
|
||||
fn from(err: SparsityPatternFormatError) -> Self {
|
||||
use SparsityPatternFormatError::*;
|
||||
use SparsityPatternFormatError::DuplicateEntry as PatternDuplicateEntry;
|
||||
use crate::SparseFormatErrorKind;
|
||||
use crate::SparseFormatErrorKind::*;
|
||||
match err {
|
||||
InvalidOffsetArrayLength
|
||||
| InvalidOffsetFirstLast
|
||||
| NonmonotonicOffsets
|
||||
| NonmonotonicMinorIndices
|
||||
=> SparseFormatError::from_kind_and_error(InvalidStructure, Box::from(err)),
|
||||
MinorIndexOutOfBounds
|
||||
=> SparseFormatError::from_kind_and_error(IndexOutOfBounds,
|
||||
Box::from(err)),
|
||||
PatternDuplicateEntry
|
||||
=> SparseFormatError::from_kind_and_error(SparseFormatErrorKind::DuplicateEntry,
|
||||
Box::from(err)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for SparsityPatternFormatError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
SparsityPatternFormatError::InvalidOffsetArrayLength => {
|
||||
write!(f, "Length of offset array is not equal to (major_dim + 1).")
|
||||
},
|
||||
SparsityPatternFormatError::InvalidOffsetFirstLast => {
|
||||
write!(f, "First or last offset is incompatible with format.")
|
||||
},
|
||||
SparsityPatternFormatError::NonmonotonicOffsets => {
|
||||
write!(f, "Offsets are not monotonically increasing.")
|
||||
},
|
||||
SparsityPatternFormatError::MinorIndexOutOfBounds => {
|
||||
write!(f, "A minor index is out of bounds.")
|
||||
},
|
||||
SparsityPatternFormatError::DuplicateEntry => {
|
||||
write!(f, "Input data contains duplicate entries.")
|
||||
},
|
||||
SparsityPatternFormatError::NonmonotonicMinorIndices => {
|
||||
write!(f, "Minor indices are not monotonically increasing within each lane.")
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Error for SparsityPatternFormatError {}
|
||||
|
||||
/// Iterator type for iterating over entries in a sparsity pattern.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SparsityPatternIter<'a> {
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
use nalgebra_sparse::{CooMatrix, SparseFormatError};
|
||||
use nalgebra_sparse::{CooMatrix, SparseFormatErrorKind};
|
||||
use nalgebra::DMatrix;
|
||||
use crate::assert_panics;
|
||||
|
||||
|
@ -91,25 +91,25 @@ fn coo_try_from_triplets_reports_out_of_bounds_indices() {
|
|||
{
|
||||
// 0x0 matrix
|
||||
let result = CooMatrix::<i32>::try_from_triplets(0, 0, vec![0], vec![0], vec![2]);
|
||||
assert!(matches!(result, Err(SparseFormatError::IndexOutOfBounds(_))));
|
||||
assert!(matches!(result.unwrap_err().kind(), SparseFormatErrorKind::IndexOutOfBounds));
|
||||
}
|
||||
|
||||
{
|
||||
// 1x1 matrix, row out of bounds
|
||||
let result = CooMatrix::<i32>::try_from_triplets(1, 1, vec![1], vec![0], vec![2]);
|
||||
assert!(matches!(result, Err(SparseFormatError::IndexOutOfBounds(_))));
|
||||
assert!(matches!(result.unwrap_err().kind(), SparseFormatErrorKind::IndexOutOfBounds));
|
||||
}
|
||||
|
||||
{
|
||||
// 1x1 matrix, col out of bounds
|
||||
let result = CooMatrix::<i32>::try_from_triplets(1, 1, vec![0], vec![1], vec![2]);
|
||||
assert!(matches!(result, Err(SparseFormatError::IndexOutOfBounds(_))));
|
||||
assert!(matches!(result.unwrap_err().kind(), SparseFormatErrorKind::IndexOutOfBounds));
|
||||
}
|
||||
|
||||
{
|
||||
// 1x1 matrix, row and col out of bounds
|
||||
let result = CooMatrix::<i32>::try_from_triplets(1, 1, vec![1], vec![1], vec![2]);
|
||||
assert!(matches!(result, Err(SparseFormatError::IndexOutOfBounds(_))));
|
||||
assert!(matches!(result.unwrap_err().kind(), SparseFormatErrorKind::IndexOutOfBounds));
|
||||
}
|
||||
|
||||
{
|
||||
|
@ -118,7 +118,7 @@ fn coo_try_from_triplets_reports_out_of_bounds_indices() {
|
|||
let j = vec![0, 2, 1, 3, 3];
|
||||
let v = vec![2, 3, 7, 3, 1];
|
||||
let result = CooMatrix::<i32>::try_from_triplets(3, 5, i, j, v);
|
||||
assert!(matches!(result, Err(SparseFormatError::IndexOutOfBounds(_))));
|
||||
assert!(matches!(result.unwrap_err().kind(), SparseFormatErrorKind::IndexOutOfBounds));
|
||||
}
|
||||
|
||||
{
|
||||
|
@ -127,7 +127,7 @@ fn coo_try_from_triplets_reports_out_of_bounds_indices() {
|
|||
let j = vec![0, 2, 1, 5, 3];
|
||||
let v = vec![2, 3, 7, 3, 1];
|
||||
let result = CooMatrix::<i32>::try_from_triplets(3, 5, i, j, v);
|
||||
assert!(matches!(result, Err(SparseFormatError::IndexOutOfBounds(_))));
|
||||
assert!(matches!(result.unwrap_err().kind(), SparseFormatErrorKind::IndexOutOfBounds));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -136,7 +136,7 @@ fn coo_try_from_triplets_panics_on_mismatched_vectors() {
|
|||
// Check that try_from_triplets panics when the triplet vectors have different lengths
|
||||
macro_rules! assert_errs {
|
||||
($result:expr) => {
|
||||
assert!(matches!($result, Err(SparseFormatError::InvalidStructure(_))))
|
||||
assert!(matches!($result.unwrap_err().kind(), SparseFormatErrorKind::InvalidStructure))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue