Add tests for some csr matrix related failure cases

This commit is contained in:
Anton 2021-10-13 21:18:17 +02:00
parent 4a97989738
commit 4b41be75b0
2 changed files with 109 additions and 6 deletions

View File

@ -178,7 +178,8 @@ impl<T> CsrMatrix<T> {
/// bounds with respect to the number of columns in the matrix. If this is not the case, /// bounds with respect to the number of columns in the matrix. If this is not the case,
/// an error is returned to indicate the failure. /// an error is returned to indicate the failure.
/// ///
/// An error is returned if the data given does not conform to the CSR storage format. /// An error is returned if the data given does not conform to the CSR storage format
/// with the exception of having unsorted column indices and values.
/// See the documentation for [CsrMatrix](struct.CsrMatrix.html) for more information. /// See the documentation for [CsrMatrix](struct.CsrMatrix.html) for more information.
pub fn try_from_unsorted_csr_data( pub fn try_from_unsorted_csr_data(
num_rows: usize, num_rows: usize,
@ -190,6 +191,7 @@ impl<T> CsrMatrix<T> {
where where
T: Scalar + Zero, T: Scalar + Zero,
{ {
use SparsityPatternFormatError::*;
let count = col_indices.len(); let count = col_indices.len();
let mut p: Vec<usize> = (0..count).collect(); let mut p: Vec<usize> = (0..count).collect();
@ -215,6 +217,9 @@ impl<T> CsrMatrix<T> {
"No row offset should be greater than the number of column indices", "No row offset should be greater than the number of column indices",
)); ));
} }
if offset > next_offset {
return Err(NonmonotonicOffsets).map_err(pattern_format_error_to_csr_error);
}
p[offset..next_offset].sort_by(|a, b| { p[offset..next_offset].sort_by(|a, b| {
let x = &col_indices[*a]; let x = &col_indices[*a];
let y = &col_indices[*b]; let y = &col_indices[*b];
@ -226,15 +231,15 @@ impl<T> CsrMatrix<T> {
let sorted_col_indices: Vec<usize> = p.iter().map(|i| col_indices[*i]).collect(); let sorted_col_indices: Vec<usize> = p.iter().map(|i| col_indices[*i]).collect();
// permute values // permute values
let mut sorted_vaues: Vec<T> = vec![T::zero(); count]; let mut sorted_values: Vec<T> = vec![T::zero(); count];
apply_permutation(&mut sorted_vaues[..count], &values[..count], &p[..count]); apply_permutation(&mut sorted_values, &values, &p);
return Self::try_from_csr_data( return Self::try_from_csr_data(
num_rows, num_rows,
num_cols, num_cols,
row_offsets, row_offsets,
sorted_col_indices, sorted_col_indices,
sorted_vaues, sorted_values,
); );
} }

View File

@ -178,7 +178,7 @@ fn csr_matrix_valid_data_unsorted_column_indices() {
4, 4,
vec![0, 1, 2, 5], vec![0, 1, 2, 5],
vec![1, 3, 2, 3, 0], vec![1, 3, 2, 3, 0],
vec![5, 4, 1, 4, 1], vec![5, 4, 2, 3, 1],
) )
.unwrap(); .unwrap();
@ -187,13 +187,111 @@ fn csr_matrix_valid_data_unsorted_column_indices() {
4, 4,
vec![0, 1, 2, 5], vec![0, 1, 2, 5],
vec![1, 3, 0, 2, 3], vec![1, 3, 0, 2, 3],
vec![5, 4, 1, 1, 4], vec![5, 4, 1, 2, 3],
) )
.unwrap(); .unwrap();
assert_eq!(csr, expected_csr); assert_eq!(csr, expected_csr);
} }
#[test]
fn csr_matrix_try_from_invalid_csr_data2() {
{
// Empty offset array (invalid length)
let matrix =
CsrMatrix::try_from_unsorted_csr_data(0, 0, Vec::new(), Vec::new(), Vec::<u32>::new());
assert_eq!(
matrix.unwrap_err().kind(),
&SparseFormatErrorKind::InvalidStructure
);
}
{
// Offset array invalid length for arbitrary data
let offsets = vec![0, 3, 5];
let indices = vec![0, 1, 2, 3, 5];
let values = vec![0, 1, 2, 3, 4];
let matrix = CsrMatrix::try_from_unsorted_csr_data(3, 6, offsets, indices, values);
assert_eq!(
matrix.unwrap_err().kind(),
&SparseFormatErrorKind::InvalidStructure
);
}
{
// Invalid first entry in offsets array
let offsets = vec![1, 2, 2, 5];
let indices = vec![0, 5, 1, 2, 3];
let values = vec![0, 1, 2, 3, 4];
let matrix = CsrMatrix::try_from_unsorted_csr_data(3, 6, offsets, indices, values);
assert_eq!(
matrix.unwrap_err().kind(),
&SparseFormatErrorKind::InvalidStructure
);
}
{
// Invalid last entry in offsets array
let offsets = vec![0, 2, 2, 4];
let indices = vec![0, 5, 1, 2, 3];
let values = vec![0, 1, 2, 3, 4];
let matrix = CsrMatrix::try_from_unsorted_csr_data(3, 6, offsets, indices, values);
assert_eq!(
matrix.unwrap_err().kind(),
&SparseFormatErrorKind::InvalidStructure
);
}
{
// Invalid length of offsets array
let offsets = vec![0, 2, 2];
let indices = vec![0, 5, 1, 2, 3];
let values = vec![0, 1, 2, 3, 4];
let matrix = CsrMatrix::try_from_unsorted_csr_data(3, 6, offsets, indices, values);
assert_eq!(
matrix.unwrap_err().kind(),
&SparseFormatErrorKind::InvalidStructure
);
}
{
// Nonmonotonic offsets
let offsets = vec![0, 3, 2, 5];
let indices = vec![0, 1, 2, 3, 4];
let values = vec![0, 1, 2, 3, 4];
let matrix = CsrMatrix::try_from_unsorted_csr_data(3, 6, offsets, indices, values);
assert_eq!(
matrix.unwrap_err().kind(),
&SparseFormatErrorKind::InvalidStructure
);
}
{
// Minor index out of bounds
let offsets = vec![0, 2, 2, 5];
let indices = vec![0, 6, 1, 2, 3];
let values = vec![0, 1, 2, 3, 4];
let matrix = CsrMatrix::try_from_unsorted_csr_data(3, 6, offsets, indices, values);
assert_eq!(
matrix.unwrap_err().kind(),
&SparseFormatErrorKind::IndexOutOfBounds
);
}
{
// Duplicate entry
let offsets = vec![0, 2, 2, 5];
let indices = vec![0, 5, 2, 2, 3];
let values = vec![0, 1, 2, 3, 4];
let matrix = CsrMatrix::try_from_unsorted_csr_data(3, 6, offsets, indices, values);
assert_eq!(
matrix.unwrap_err().kind(),
&SparseFormatErrorKind::DuplicateEntry
);
}
}
#[test] #[test]
fn csr_matrix_try_from_invalid_csr_data() { fn csr_matrix_try_from_invalid_csr_data() {
{ {