Implement Serialize, Deserialize for SparsityPattern

This commit is contained in:
Fabian Loeschner 2021-11-09 10:59:24 +01:00
parent 2a3e657b56
commit bfaf29393c
2 changed files with 93 additions and 0 deletions

View File

@ -4,6 +4,9 @@ use crate::SparseFormatError;
use std::error::Error;
use std::fmt;
#[cfg(feature = "serde-serialize")]
use serde::{de, Deserialize, Deserializer, Serialize, Serializer};
/// A representation of the sparsity pattern of a CSR or CSC matrix.
///
/// CSR and CSC matrices store matrices in a very similar fashion. In fact, in a certain sense,
@ -285,6 +288,54 @@ pub enum SparsityPatternFormatError {
NonmonotonicMinorIndices,
}
#[cfg(feature = "serde-serialize")]
#[derive(Serialize)]
struct SparsityPatternSerializationData<'a> {
major_dim: usize,
minor_dim: usize,
major_offsets: &'a [usize],
minor_indices: &'a [usize],
}
#[cfg(feature = "serde-serialize")]
impl Serialize for SparsityPattern {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
SparsityPatternSerializationData {
major_dim: self.major_dim(),
minor_dim: self.minor_dim(),
major_offsets: self.major_offsets(),
minor_indices: self.minor_indices(),
}
.serialize(serializer)
}
}
#[cfg(feature = "serde-serialize")]
#[derive(Deserialize)]
struct SparsityPatternDeserializationData {
major_dim: usize,
minor_dim: usize,
major_offsets: Vec<usize>,
minor_indices: Vec<usize>,
}
#[cfg(feature = "serde-serialize")]
impl<'de> Deserialize<'de> for SparsityPattern {
fn deserialize<D>(deserializer: D) -> Result<SparsityPattern, D::Error>
where
D: Deserializer<'de>,
{
let de = SparsityPatternDeserializationData::deserialize(deserializer)?;
SparsityPattern::try_from_offsets_and_indices(de.major_dim, de.minor_dim, de.major_offsets, de.minor_indices)
.map(|m| m.into())
// TODO: More specific error
.map_err(|_e| de::Error::invalid_value(de::Unexpected::Other("invalid sparsity pattern"), &"a valid sparsity pattern"))
}
}
impl From<SparsityPatternFormatError> for SparseFormatError {
fn from(err: SparsityPatternFormatError) -> Self {
use crate::SparseFormatErrorKind;

View File

@ -9,6 +9,7 @@ pub mod common;
use nalgebra_sparse::coo::CooMatrix;
use nalgebra_sparse::csc::CscMatrix;
use nalgebra_sparse::csr::CsrMatrix;
use nalgebra_sparse::pattern::SparsityPattern;
use proptest::prelude::*;
use serde::{Deserialize, Serialize};
@ -21,6 +22,44 @@ fn json_roundtrip<T: Serialize + for<'a> Deserialize<'a>>(csr: &T) -> T {
deserialized
}
#[test]
fn pattern_roundtrip() {
{
// A pattern with zero explicitly stored entries
let pattern =
SparsityPattern::try_from_offsets_and_indices(3, 2, vec![0, 0, 0, 0], Vec::new())
.unwrap();
assert_eq!(json_roundtrip(&pattern), pattern);
}
{
// Arbitrary pattern
let offsets = vec![0, 2, 2, 5];
let indices = vec![0, 5, 1, 2, 3];
let pattern =
SparsityPattern::try_from_offsets_and_indices(3, 6, offsets.clone(), indices.clone())
.unwrap();
assert_eq!(json_roundtrip(&pattern), pattern);
}
}
#[test]
#[rustfmt::skip]
fn pattern_deserialize_invalid() {
assert!(serde_json::from_str::<SparsityPattern>(r#"{"major_dim":3,"minor_dim":6,"major_offsets":[0,2,2,5],"minor_indices":[0,5,1,2,3]}"#).is_ok());
assert!(serde_json::from_str::<SparsityPattern>(r#"{"major_dim":0,"minor_dim":0,"major_offsets":[],"minor_indices":[]}"#).is_err());
assert!(serde_json::from_str::<SparsityPattern>(r#"{"major_dim":3,"minor_dim":6,"major_offsets":[0, 3, 5],"minor_indices":[0, 1, 2, 3, 5]}"#).is_err());
assert!(serde_json::from_str::<SparsityPattern>(r#"{"major_dim":3,"minor_dim":6,"major_offsets":[1, 2, 2, 5],"minor_indices":[0, 5, 1, 2, 3]}"#).is_err());
assert!(serde_json::from_str::<SparsityPattern>(r#"{"major_dim":3,"minor_dim":6,"major_offsets":[0, 2, 2, 4],"minor_indices":[0, 5, 1, 2, 3]}"#).is_err());
assert!(serde_json::from_str::<SparsityPattern>(r#"{"major_dim":3,"minor_dim":6,"major_offsets":[0, 2, 2],"minor_indices":[0, 5, 1, 2, 3]}"#).is_err());
assert!(serde_json::from_str::<SparsityPattern>(r#"{"major_dim":3,"minor_dim":6,"major_offsets":[0, 3, 2, 5],"minor_indices":[0, 1, 2, 3, 4]}"#).is_err());
assert!(serde_json::from_str::<SparsityPattern>(r#"{"major_dim":3,"minor_dim":6,"major_offsets":[0, 2, 2, 5],"minor_indices":[0, 2, 3, 1, 4]}"#).is_err());
assert!(serde_json::from_str::<SparsityPattern>(r#"{"major_dim":3,"minor_dim":6,"major_offsets":[0, 2, 2, 5],"minor_indices":[0, 6, 1, 2, 3]}"#).is_err());
assert!(serde_json::from_str::<SparsityPattern>(r#"{"major_dim":3,"minor_dim":6,"major_offsets":[0, 2, 2, 5],"minor_indices":[0, 5, 2, 2, 3]}"#).is_err());
}
#[test]
fn coo_roundtrip() {
{
@ -46,6 +85,7 @@ fn coo_roundtrip() {
#[test]
fn coo_deserialize_invalid() {
// Valid matrix: {"nrows":3,"ncols":5,"row_indices":[0,1,0,0,2],"col_indices":[0,2,1,3,3],"values":[2,3,7,3,1]}
assert!(serde_json::from_str::<CooMatrix<i32>>(r#"{"nrows":3,"ncols":5,"row_indices":[0,1,0,0,2],"col_indices":[0,2,1,3,3],"values":[2,3,7,3,1]}"#).is_ok());
assert!(serde_json::from_str::<CooMatrix<i32>>(r#"{"nrows":0,"ncols":0,"row_indices":[0,1,0,0,2],"col_indices":[0,2,1,3,3],"values":[2,3,7,3,4]}"#).is_err());
assert!(serde_json::from_str::<CooMatrix<i32>>(r#"{"nrows":-3,"ncols":5,"row_indices":[0,1,0,0,2],"col_indices":[0,2,1,3,3],"values":[2,3,7,3,4]}"#).is_err());
assert!(serde_json::from_str::<CooMatrix<i32>>(r#"{"nrows":3,"ncols":5,"row_indices":[0,1,0,0,2],"col_indices":[0,2,1,3,3],"values":[2,3,7,3]}"#).is_err());
@ -101,6 +141,7 @@ fn csc_roundtrip() {
#[test]
fn csc_deserialize_invalid() {
// Valid matrix: {"nrows":6,"ncols":3,"col_offsets":[0,2,2,5],"row_indices":[0,5,1,2,3],"values":[0,1,2,3,4]}
assert!(serde_json::from_str::<CscMatrix<i32>>(r#"{"nrows":6,"ncols":3,"col_offsets":[0,2,2,5],"row_indices":[0,5,1,2,3],"values":[0,1,2,3,4]}"#).is_ok());
assert!(serde_json::from_str::<CscMatrix<i32>>(r#"{"nrows":0,"ncols":0,"col_offsets":[0,2,2,5],"row_indices":[0,5,1,2,3],"values":[0,1,2,3,4]}"#).is_err());
assert!(serde_json::from_str::<CscMatrix<i32>>(r#"{"nrows":-6,"ncols":3,"col_offsets":[0,2,2,5],"row_indices":[0,5,1,2,3],"values":[0,1,2,3,4]}"#).is_err());
assert!(serde_json::from_str::<CscMatrix<i32>>(r#"{"nrows":6,"ncols":3,"col_offsets":[0,2,2,5],"row_indices":[0,5,1,2,3],"values":[0,1,2,3]}"#).is_err());
@ -140,6 +181,7 @@ fn csr_roundtrip() {
#[test]
fn csr_deserialize_invalid() {
// Valid matrix: {"nrows":3,"ncols":6,"row_offsets":[0,2,2,5],"col_indices":[0,5,1,2,3],"values":[0,1,2,3,4]}
assert!(serde_json::from_str::<CsrMatrix<i32>>(r#"{"nrows":3,"ncols":6,"row_offsets":[0,2,2,5],"col_indices":[0,5,1,2,3],"values":[0,1,2,3,4]}"#).is_ok());
assert!(serde_json::from_str::<CsrMatrix<i32>>(r#"{"nrows":0,"ncols":0,"row_offsets":[0,2,2,5],"col_indices":[0,5,1,2,3],"values":[0,1,2,3,4]}"#).is_err());
assert!(serde_json::from_str::<CsrMatrix<i32>>(r#"{"nrows":-3,"ncols":6,"row_offsets":[0,2,2,5],"col_indices":[0,5,1,2,3],"values":[0,1,2,3,4]}"#).is_err());
assert!(serde_json::from_str::<CsrMatrix<i32>>(r#"{"nrows":3,"ncols":6,"row_offsets":[0,2,2,5],"col_indices":[0,5,1,2,3],"values":[0,1,2,3]}"#).is_err());