Merge SolveError into OperationError

This commit is contained in:
Andreas Longva 2021-01-25 17:05:13 +01:00
parent 7b6333e9d1
commit ccf1f18991
2 changed files with 34 additions and 35 deletions

View File

@ -1,7 +1,7 @@
use crate::csc::CscMatrix; use crate::csc::CscMatrix;
use crate::ops::Op; use crate::ops::Op;
use crate::ops::serial::cs::{spmm_cs_prealloc, spmm_cs_dense, spadd_cs_prealloc}; use crate::ops::serial::cs::{spmm_cs_prealloc, spmm_cs_dense, spadd_cs_prealloc};
use crate::ops::serial::OperationError; use crate::ops::serial::{OperationError, OperationErrorKind};
use nalgebra::{Scalar, ClosedAdd, ClosedMul, DMatrixSliceMut, DMatrixSlice, RealField}; use nalgebra::{Scalar, ClosedAdd, ClosedMul, DMatrixSliceMut, DMatrixSlice, RealField};
use num_traits::{Zero, One}; use num_traits::{Zero, One};
@ -91,30 +91,6 @@ pub fn spmm_csc_prealloc<T>(
} }
} }
/// TODO
#[non_exhaustive]
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum SolveErrorKind {
/// TODO
Singular,
}
/// TODO
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct SolveError {
kind: SolveErrorKind,
message: String
}
impl SolveError {
fn from_type_and_message(kind: SolveErrorKind, message: String) -> Self {
Self {
kind,
message
}
}
}
/// Solve the lower triangular system `op(L) X = B`. /// Solve the lower triangular system `op(L) X = B`.
/// ///
/// Only the lower triangular part of L is read, and the result is stored in B. /// Only the lower triangular part of L is read, and the result is stored in B.
@ -125,7 +101,7 @@ impl SolveError {
pub fn spsolve_csc_lower_triangular<'a, T: RealField>( pub fn spsolve_csc_lower_triangular<'a, T: RealField>(
l: Op<&CscMatrix<T>>, l: Op<&CscMatrix<T>>,
b: impl Into<DMatrixSliceMut<'a, T>>) b: impl Into<DMatrixSliceMut<'a, T>>)
-> Result<(), SolveError> -> Result<(), OperationError>
{ {
let b = b.into(); let b = b.into();
let l_matrix = l.into_inner(); let l_matrix = l.into_inner();
@ -137,10 +113,10 @@ pub fn spsolve_csc_lower_triangular<'a, T: RealField>(
} }
} }
fn spsolve_csc_lower_triangular_no_transpose<'a, T: RealField>( fn spsolve_csc_lower_triangular_no_transpose<T: RealField>(
l: &CscMatrix<T>, l: &CscMatrix<T>,
b: DMatrixSliceMut<'a, T>) b: DMatrixSliceMut<T>)
-> Result<(), SolveError> -> Result<(), OperationError>
{ {
let mut x = b; let mut x = b;
@ -188,15 +164,15 @@ fn spsolve_csc_lower_triangular_no_transpose<'a, T: RealField>(
Ok(()) Ok(())
} }
fn spsolve_encountered_zero_diagonal() -> Result<(), SolveError> { fn spsolve_encountered_zero_diagonal() -> Result<(), OperationError> {
let message = "Matrix contains at least one diagonal entry that is zero."; let message = "Matrix contains at least one diagonal entry that is zero.";
Err(SolveError::from_type_and_message(SolveErrorKind::Singular, String::from(message))) Err(OperationError::from_kind_and_message(OperationErrorKind::Singular, String::from(message)))
} }
fn spsolve_csc_lower_triangular_transpose<'a, T: RealField>( fn spsolve_csc_lower_triangular_transpose<T: RealField>(
l: &CscMatrix<T>, l: &CscMatrix<T>,
b: DMatrixSliceMut<'a, T>) b: DMatrixSliceMut<T>)
-> Result<(), SolveError> -> Result<(), OperationError>
{ {
let mut x = b; let mut x = b;

View File

@ -65,6 +65,8 @@ mod cs;
pub use csc::*; pub use csc::*;
pub use csr::*; pub use csr::*;
pub use pattern::*; pub use pattern::*;
use std::fmt::Formatter;
use std::fmt;
/// A description of the error that occurred during an arithmetic operation. /// A description of the error that occurred during an arithmetic operation.
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
@ -83,6 +85,9 @@ pub enum OperationErrorKind {
/// For example, this could indicate that the sparsity pattern of the output is not able to /// For example, this could indicate that the sparsity pattern of the output is not able to
/// contain the result of the operation. /// contain the result of the operation.
InvalidPattern, InvalidPattern,
/// Indicates that a matrix is singular when it is expected to be invertible.
Singular,
} }
impl OperationError { impl OperationError {
@ -94,4 +99,22 @@ impl OperationError {
pub fn kind(&self) -> &OperationErrorKind { pub fn kind(&self) -> &OperationErrorKind {
&self.error_kind &self.error_kind
} }
/// The underlying error message.
pub fn message(&self) -> &str {
self.message.as_str()
}
} }
impl fmt::Display for OperationError {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
write!(f, "Sparse matrix operation error: ")?;
match self.kind() {
OperationErrorKind::InvalidPattern => { write!(f, "InvalidPattern")?; }
OperationErrorKind::Singular => { write!(f, "Singular")?; }
}
write!(f, " Message: {}", self.message)
}
}
impl std::error::Error for OperationError {}