Redesign error handling for CSR and SparsityPattern construction

SparsityPattern's constructor now returns a fine-grained error
enum that enumerates possible errors. We use this to build a more
user-friendly error when constructing CSR matrices.

We also overhauled the main SparseFormatError error type by
making it a struct containing a *Kind type and an underlying error
that contains the message.
This commit is contained in:
Andreas Longva 2020-09-22 17:50:47 +02:00
parent 7e94a1539a
commit 7a5f8ef1ea
5 changed files with 178 additions and 58 deletions

View File

@ -76,13 +76,14 @@ where
col_indices: Vec<usize>,
values: Vec<T>,
) -> Result<Self, SparseFormatError> {
use crate::SparseFormatErrorKind::*;
if row_indices.len() != col_indices.len() {
return Err(SparseFormatError::InvalidStructure(
Box::from("Number of row and col indices must be the same.")
return Err(SparseFormatError::from_kind_and_msg(
InvalidStructure, "Number of row and col indices must be the same."
));
} else if col_indices.len() != values.len() {
return Err(SparseFormatError::InvalidStructure(
Box::from("Number of col indices and values must be the same.")
return Err(SparseFormatError::from_kind_and_msg(
InvalidStructure, "Number of col indices and values must be the same."
));
}
@ -90,13 +91,9 @@ where
let col_indices_in_bounds = col_indices.iter().all(|j| *j < ncols);
if !row_indices_in_bounds {
Err(SparseFormatError::IndexOutOfBounds(Box::from(
"Row index out of bounds.",
)))
Err(SparseFormatError::from_kind_and_msg(IndexOutOfBounds, "Row index out of bounds."))
} else if !col_indices_in_bounds {
Err(SparseFormatError::IndexOutOfBounds(Box::from(
"Col index out of bounds.",
)))
Err(SparseFormatError::from_kind_and_msg(IndexOutOfBounds, "Col index out of bounds."))
} else {
Ok(Self {
nrows,

View File

@ -1,4 +1,4 @@
use crate::{SparsityPattern, SparseFormatError};
use crate::{SparsityPattern, SparseFormatError, SparsityPatternFormatError, SparseFormatErrorKind};
use crate::iter::SparsityPatternIter;
use std::sync::Arc;
@ -92,7 +92,8 @@ impl<T> CsrMatrix<T> {
values: Vec<T>,
) -> Result<Self, SparseFormatError> {
let pattern = SparsityPattern::try_from_offsets_and_indices(
num_rows, num_cols, row_offsets, col_indices)?;
num_rows, num_cols, row_offsets, col_indices)
.map_err(pattern_format_error_to_csr_error)?;
Self::try_from_pattern_and_values(Arc::new(pattern), values)
}
@ -108,8 +109,9 @@ impl<T> CsrMatrix<T> {
values,
})
} else {
return Err(SparseFormatError::InvalidStructure(
Box::from("Number of values and column indices must be the same")));
Err(SparseFormatError::from_kind_and_msg(
SparseFormatErrorKind::InvalidStructure,
"Number of values and column indices must be the same"))
}
}
@ -264,6 +266,38 @@ impl<T: Clone + Zero> CsrMatrix<T> {
}
}
/// Convert pattern format errors into more meaningful CSR-specific errors.
///
/// This ensures that the terminology is consistent: we are talking about rows and columns,
/// not lanes, major and minor dimensions.
fn pattern_format_error_to_csr_error(err: SparsityPatternFormatError) -> SparseFormatError {
use SparsityPatternFormatError::*;
use SparsityPatternFormatError::DuplicateEntry as PatternDuplicateEntry;
use SparseFormatError as E;
use SparseFormatErrorKind as K;
match err {
InvalidOffsetArrayLength => E::from_kind_and_msg(
K::InvalidStructure,
"Length of row offset array is not equal to nrows + 1."),
InvalidOffsetFirstLast => E::from_kind_and_msg(
K::InvalidStructure,
"First or last row offset is inconsistent with format specification."),
NonmonotonicOffsets => E::from_kind_and_msg(
K::InvalidStructure,
"Row offsets are not monotonically increasing."),
NonmonotonicMinorIndices => E::from_kind_and_msg(
K::InvalidStructure,
"Column indices are not monotonically increasing (sorted) within each row."),
MinorIndexOutOfBounds => E::from_kind_and_msg(
K::IndexOutOfBounds,
"Column indices are out of bounds."),
PatternDuplicateEntry => E::from_kind_and_msg(
K::DuplicateEntry,
"Matrix data contains duplicate entries."),
}
}
/// Iterator type for iterating over triplets in a CSR matrix.
#[derive(Debug)]
pub struct CsrTripletIter<'a, T> {
@ -360,7 +394,7 @@ macro_rules! impl_csr_row_common_methods {
/// bounds.
///
/// If the index is in bounds, but no explicitly stored entry is associated with it,
/// `T::zero()` is returned. Note that this methods offers no way of distinguishing
/// `T::zero()` is returned. Note that this method offers no way of distinguishing
/// explicitly stored zero entries from zero values that are only implicitly represented.
///
/// Each call to this function incurs the cost of a binary search among the explicitly

View File

@ -73,7 +73,7 @@ pub mod ops;
pub use coo::CooMatrix;
pub use csr::{CsrMatrix, CsrRow, CsrRowMut};
pub use pattern::{SparsityPattern};
pub use pattern::{SparsityPattern, SparsityPatternFormatError};
/// Iterator types for matrices.
///
@ -94,30 +94,52 @@ use std::fmt;
/// Errors produced by functions that expect well-formed sparse format data.
#[derive(Debug)]
#[non_exhaustive]
pub enum SparseFormatError {
pub struct SparseFormatError {
kind: SparseFormatErrorKind,
// Currently we only use an underlying error for generating the `Display` impl
error: Box<dyn Error>
}
impl SparseFormatError {
/// The type of error.
pub fn kind(&self) -> &SparseFormatErrorKind {
&self.kind
}
pub(crate) fn from_kind_and_error(kind: SparseFormatErrorKind, error: Box<dyn Error>) -> Self {
Self {
kind,
error
}
}
/// Helper functionality for more conveniently creating errors.
pub(crate) fn from_kind_and_msg(kind: SparseFormatErrorKind, msg: &'static str) -> Self {
Self::from_kind_and_error(kind, Box::<dyn Error>::from(msg))
}
}
/// The type of format error described by a [SparseFormatError](struct.SparseFormatError.html).
#[derive(Debug, Clone)]
pub enum SparseFormatErrorKind {
/// Indicates that the index data associated with the format contains at least one index
/// out of bounds.
IndexOutOfBounds(Box<dyn Error>),
IndexOutOfBounds,
/// Indicates that the provided data contains at least one duplicate entry, and the
/// current format does not support duplicate entries.
DuplicateEntry(Box<dyn Error>),
DuplicateEntry,
/// Indicates that the provided data for the format does not conform to the high-level
/// structure of the format.
///
/// For example, the arrays defining the format data might have incompatible sizes.
InvalidStructure(Box<dyn Error>),
InvalidStructure,
}
impl fmt::Display for SparseFormatError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Self::IndexOutOfBounds(err) => err.fmt(f),
Self::DuplicateEntry(err) => err.fmt(f),
Self::InvalidStructure(err) => err.fmt(f)
}
write!(f, "{}", self.error)
}
}

View File

@ -1,4 +1,6 @@
use crate::SparseFormatError;
use std::fmt;
use std::error::Error;
/// A representation of the sparsity pattern of a CSR or CSC matrix.
///
@ -79,27 +81,24 @@ impl SparsityPattern {
minor_dim: usize,
major_offsets: Vec<usize>,
minor_indices: Vec<usize>,
) -> Result<Self, SparseFormatError> {
) -> Result<Self, SparsityPatternFormatError> {
// TODO: If these errors are *directly* propagated to errors from e.g.
// CSR construction, the error messages will be confusing to users,
// as the error messages refer to "major" and "minor" lanes, as opposed to
// rows and columns
use SparsityPatternFormatError::*;
if major_offsets.len() != major_dim + 1 {
return Err(SparseFormatError::InvalidStructure(
Box::from("Size of major_offsets must be equal to (major_dim + 1)")));
return Err(InvalidOffsetArrayLength);
}
// Check that the first and last offsets conform to the specification
{
if *major_offsets.first().unwrap() != 0 {
return Err(SparseFormatError::InvalidStructure(
Box::from("First entry in major_offsets must always be 0.")
));
} else if *major_offsets.last().unwrap() != minor_indices.len() {
return Err(SparseFormatError::InvalidStructure(
Box::from("Last entry in major_offsets must always be equal to minor_indices.len()")
));
let first_offset_ok = *major_offsets.first().unwrap() == 0;
let last_offset_ok = *major_offsets.last().unwrap() == minor_indices.len();
if !first_offset_ok || !last_offset_ok {
return Err(InvalidOffsetFirstLast);
}
}
@ -113,9 +112,7 @@ impl SparsityPattern {
// Test that major offsets are monotonically increasing
if range_start > range_end {
return Err(SparseFormatError::InvalidStructure(
Box::from("Major offsets are not monotonically increasing.")
));
return Err(NonmonotonicOffsets);
}
let minor_indices = &minor_indices[range_start .. range_end];
@ -127,20 +124,14 @@ impl SparsityPattern {
while let Some(next) = iter.next().copied() {
if next > minor_dim {
return Err(SparseFormatError::IndexOutOfBounds(
Box::from("Minor index out of bounds.")
));
return Err(MinorIndexOutOfBounds);
}
if let Some(prev) = prev {
if prev > next {
return Err(SparseFormatError::InvalidStructure(
Box::from("Minor indices within a lane must be monotonically increasing (sorted).")
));
return Err(NonmonotonicMinorIndices);
} else if prev == next {
return Err(SparseFormatError::DuplicateEntry(
Box::from("Duplicate minor entries detected.")
));
return Err(DuplicateEntry);
}
}
prev = Some(next);
@ -180,6 +171,82 @@ impl SparsityPattern {
}
}
/// Error type for `SparsityPattern` format errors.
#[non_exhaustive]
#[derive(Debug)]
pub enum SparsityPatternFormatError {
/// Indicates an invalid number of offsets.
///
/// The number of offsets must be equal to (major_dim + 1).
InvalidOffsetArrayLength,
/// Indicates that the first or last entry in the offset array did not conform to
/// specifications.
///
/// The first entry must be 0, and the last entry must be exactly one greater than the
/// major dimension.
InvalidOffsetFirstLast,
/// Indicates that the major offsets are not monotonically increasing.
NonmonotonicOffsets,
/// One or more minor indices are out of bounds.
MinorIndexOutOfBounds,
/// One or more duplicate entries were detected.
///
/// Two entries are considered duplicates if they are part of the same major lane and have
/// the same minor index.
DuplicateEntry,
/// Indicates that minor indices are not monotonically increasing within each lane.
NonmonotonicMinorIndices,
}
impl From<SparsityPatternFormatError> for SparseFormatError {
fn from(err: SparsityPatternFormatError) -> Self {
use SparsityPatternFormatError::*;
use SparsityPatternFormatError::DuplicateEntry as PatternDuplicateEntry;
use crate::SparseFormatErrorKind;
use crate::SparseFormatErrorKind::*;
match err {
InvalidOffsetArrayLength
| InvalidOffsetFirstLast
| NonmonotonicOffsets
| NonmonotonicMinorIndices
=> SparseFormatError::from_kind_and_error(InvalidStructure, Box::from(err)),
MinorIndexOutOfBounds
=> SparseFormatError::from_kind_and_error(IndexOutOfBounds,
Box::from(err)),
PatternDuplicateEntry
=> SparseFormatError::from_kind_and_error(SparseFormatErrorKind::DuplicateEntry,
Box::from(err)),
}
}
}
impl fmt::Display for SparsityPatternFormatError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
SparsityPatternFormatError::InvalidOffsetArrayLength => {
write!(f, "Length of offset array is not equal to (major_dim + 1).")
},
SparsityPatternFormatError::InvalidOffsetFirstLast => {
write!(f, "First or last offset is incompatible with format.")
},
SparsityPatternFormatError::NonmonotonicOffsets => {
write!(f, "Offsets are not monotonically increasing.")
},
SparsityPatternFormatError::MinorIndexOutOfBounds => {
write!(f, "A minor index is out of bounds.")
},
SparsityPatternFormatError::DuplicateEntry => {
write!(f, "Input data contains duplicate entries.")
},
SparsityPatternFormatError::NonmonotonicMinorIndices => {
write!(f, "Minor indices are not monotonically increasing within each lane.")
},
}
}
}
impl Error for SparsityPatternFormatError {}
/// Iterator type for iterating over entries in a sparsity pattern.
#[derive(Debug, Clone)]
pub struct SparsityPatternIter<'a> {

View File

@ -1,4 +1,4 @@
use nalgebra_sparse::{CooMatrix, SparseFormatError};
use nalgebra_sparse::{CooMatrix, SparseFormatErrorKind};
use nalgebra::DMatrix;
use crate::assert_panics;
@ -91,25 +91,25 @@ fn coo_try_from_triplets_reports_out_of_bounds_indices() {
{
// 0x0 matrix
let result = CooMatrix::<i32>::try_from_triplets(0, 0, vec![0], vec![0], vec![2]);
assert!(matches!(result, Err(SparseFormatError::IndexOutOfBounds(_))));
assert!(matches!(result.unwrap_err().kind(), SparseFormatErrorKind::IndexOutOfBounds));
}
{
// 1x1 matrix, row out of bounds
let result = CooMatrix::<i32>::try_from_triplets(1, 1, vec![1], vec![0], vec![2]);
assert!(matches!(result, Err(SparseFormatError::IndexOutOfBounds(_))));
assert!(matches!(result.unwrap_err().kind(), SparseFormatErrorKind::IndexOutOfBounds));
}
{
// 1x1 matrix, col out of bounds
let result = CooMatrix::<i32>::try_from_triplets(1, 1, vec![0], vec![1], vec![2]);
assert!(matches!(result, Err(SparseFormatError::IndexOutOfBounds(_))));
assert!(matches!(result.unwrap_err().kind(), SparseFormatErrorKind::IndexOutOfBounds));
}
{
// 1x1 matrix, row and col out of bounds
let result = CooMatrix::<i32>::try_from_triplets(1, 1, vec![1], vec![1], vec![2]);
assert!(matches!(result, Err(SparseFormatError::IndexOutOfBounds(_))));
assert!(matches!(result.unwrap_err().kind(), SparseFormatErrorKind::IndexOutOfBounds));
}
{
@ -118,7 +118,7 @@ fn coo_try_from_triplets_reports_out_of_bounds_indices() {
let j = vec![0, 2, 1, 3, 3];
let v = vec![2, 3, 7, 3, 1];
let result = CooMatrix::<i32>::try_from_triplets(3, 5, i, j, v);
assert!(matches!(result, Err(SparseFormatError::IndexOutOfBounds(_))));
assert!(matches!(result.unwrap_err().kind(), SparseFormatErrorKind::IndexOutOfBounds));
}
{
@ -127,7 +127,7 @@ fn coo_try_from_triplets_reports_out_of_bounds_indices() {
let j = vec![0, 2, 1, 5, 3];
let v = vec![2, 3, 7, 3, 1];
let result = CooMatrix::<i32>::try_from_triplets(3, 5, i, j, v);
assert!(matches!(result, Err(SparseFormatError::IndexOutOfBounds(_))));
assert!(matches!(result.unwrap_err().kind(), SparseFormatErrorKind::IndexOutOfBounds));
}
}
@ -136,7 +136,7 @@ fn coo_try_from_triplets_panics_on_mismatched_vectors() {
// Check that try_from_triplets panics when the triplet vectors have different lengths
macro_rules! assert_errs {
($result:expr) => {
assert!(matches!($result, Err(SparseFormatError::InvalidStructure(_))))
assert!(matches!($result.unwrap_err().kind(), SparseFormatErrorKind::InvalidStructure))
}
}