Implement Serialize, Deserialize for SparsityPattern
This commit is contained in:
parent
2a3e657b56
commit
bfaf29393c
|
@ -4,6 +4,9 @@ use crate::SparseFormatError;
|
|||
use std::error::Error;
|
||||
use std::fmt;
|
||||
|
||||
#[cfg(feature = "serde-serialize")]
|
||||
use serde::{de, Deserialize, Deserializer, Serialize, Serializer};
|
||||
|
||||
/// A representation of the sparsity pattern of a CSR or CSC matrix.
|
||||
///
|
||||
/// CSR and CSC matrices store matrices in a very similar fashion. In fact, in a certain sense,
|
||||
|
@ -285,6 +288,54 @@ pub enum SparsityPatternFormatError {
|
|||
NonmonotonicMinorIndices,
|
||||
}
|
||||
|
||||
#[cfg(feature = "serde-serialize")]
|
||||
#[derive(Serialize)]
|
||||
struct SparsityPatternSerializationData<'a> {
|
||||
major_dim: usize,
|
||||
minor_dim: usize,
|
||||
major_offsets: &'a [usize],
|
||||
minor_indices: &'a [usize],
|
||||
}
|
||||
|
||||
#[cfg(feature = "serde-serialize")]
|
||||
impl Serialize for SparsityPattern {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: Serializer,
|
||||
{
|
||||
SparsityPatternSerializationData {
|
||||
major_dim: self.major_dim(),
|
||||
minor_dim: self.minor_dim(),
|
||||
major_offsets: self.major_offsets(),
|
||||
minor_indices: self.minor_indices(),
|
||||
}
|
||||
.serialize(serializer)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "serde-serialize")]
|
||||
#[derive(Deserialize)]
|
||||
struct SparsityPatternDeserializationData {
|
||||
major_dim: usize,
|
||||
minor_dim: usize,
|
||||
major_offsets: Vec<usize>,
|
||||
minor_indices: Vec<usize>,
|
||||
}
|
||||
|
||||
#[cfg(feature = "serde-serialize")]
|
||||
impl<'de> Deserialize<'de> for SparsityPattern {
|
||||
fn deserialize<D>(deserializer: D) -> Result<SparsityPattern, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
let de = SparsityPatternDeserializationData::deserialize(deserializer)?;
|
||||
SparsityPattern::try_from_offsets_and_indices(de.major_dim, de.minor_dim, de.major_offsets, de.minor_indices)
|
||||
.map(|m| m.into())
|
||||
// TODO: More specific error
|
||||
.map_err(|_e| de::Error::invalid_value(de::Unexpected::Other("invalid sparsity pattern"), &"a valid sparsity pattern"))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<SparsityPatternFormatError> for SparseFormatError {
|
||||
fn from(err: SparsityPatternFormatError) -> Self {
|
||||
use crate::SparseFormatErrorKind;
|
||||
|
|
|
@ -9,6 +9,7 @@ pub mod common;
|
|||
use nalgebra_sparse::coo::CooMatrix;
|
||||
use nalgebra_sparse::csc::CscMatrix;
|
||||
use nalgebra_sparse::csr::CsrMatrix;
|
||||
use nalgebra_sparse::pattern::SparsityPattern;
|
||||
|
||||
use proptest::prelude::*;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
@ -21,6 +22,44 @@ fn json_roundtrip<T: Serialize + for<'a> Deserialize<'a>>(csr: &T) -> T {
|
|||
deserialized
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn pattern_roundtrip() {
|
||||
{
|
||||
// A pattern with zero explicitly stored entries
|
||||
let pattern =
|
||||
SparsityPattern::try_from_offsets_and_indices(3, 2, vec![0, 0, 0, 0], Vec::new())
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(json_roundtrip(&pattern), pattern);
|
||||
}
|
||||
|
||||
{
|
||||
// Arbitrary pattern
|
||||
let offsets = vec![0, 2, 2, 5];
|
||||
let indices = vec![0, 5, 1, 2, 3];
|
||||
let pattern =
|
||||
SparsityPattern::try_from_offsets_and_indices(3, 6, offsets.clone(), indices.clone())
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(json_roundtrip(&pattern), pattern);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[rustfmt::skip]
|
||||
fn pattern_deserialize_invalid() {
|
||||
assert!(serde_json::from_str::<SparsityPattern>(r#"{"major_dim":3,"minor_dim":6,"major_offsets":[0,2,2,5],"minor_indices":[0,5,1,2,3]}"#).is_ok());
|
||||
assert!(serde_json::from_str::<SparsityPattern>(r#"{"major_dim":0,"minor_dim":0,"major_offsets":[],"minor_indices":[]}"#).is_err());
|
||||
assert!(serde_json::from_str::<SparsityPattern>(r#"{"major_dim":3,"minor_dim":6,"major_offsets":[0, 3, 5],"minor_indices":[0, 1, 2, 3, 5]}"#).is_err());
|
||||
assert!(serde_json::from_str::<SparsityPattern>(r#"{"major_dim":3,"minor_dim":6,"major_offsets":[1, 2, 2, 5],"minor_indices":[0, 5, 1, 2, 3]}"#).is_err());
|
||||
assert!(serde_json::from_str::<SparsityPattern>(r#"{"major_dim":3,"minor_dim":6,"major_offsets":[0, 2, 2, 4],"minor_indices":[0, 5, 1, 2, 3]}"#).is_err());
|
||||
assert!(serde_json::from_str::<SparsityPattern>(r#"{"major_dim":3,"minor_dim":6,"major_offsets":[0, 2, 2],"minor_indices":[0, 5, 1, 2, 3]}"#).is_err());
|
||||
assert!(serde_json::from_str::<SparsityPattern>(r#"{"major_dim":3,"minor_dim":6,"major_offsets":[0, 3, 2, 5],"minor_indices":[0, 1, 2, 3, 4]}"#).is_err());
|
||||
assert!(serde_json::from_str::<SparsityPattern>(r#"{"major_dim":3,"minor_dim":6,"major_offsets":[0, 2, 2, 5],"minor_indices":[0, 2, 3, 1, 4]}"#).is_err());
|
||||
assert!(serde_json::from_str::<SparsityPattern>(r#"{"major_dim":3,"minor_dim":6,"major_offsets":[0, 2, 2, 5],"minor_indices":[0, 6, 1, 2, 3]}"#).is_err());
|
||||
assert!(serde_json::from_str::<SparsityPattern>(r#"{"major_dim":3,"minor_dim":6,"major_offsets":[0, 2, 2, 5],"minor_indices":[0, 5, 2, 2, 3]}"#).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn coo_roundtrip() {
|
||||
{
|
||||
|
@ -46,6 +85,7 @@ fn coo_roundtrip() {
|
|||
#[test]
|
||||
fn coo_deserialize_invalid() {
|
||||
// Valid matrix: {"nrows":3,"ncols":5,"row_indices":[0,1,0,0,2],"col_indices":[0,2,1,3,3],"values":[2,3,7,3,1]}
|
||||
assert!(serde_json::from_str::<CooMatrix<i32>>(r#"{"nrows":3,"ncols":5,"row_indices":[0,1,0,0,2],"col_indices":[0,2,1,3,3],"values":[2,3,7,3,1]}"#).is_ok());
|
||||
assert!(serde_json::from_str::<CooMatrix<i32>>(r#"{"nrows":0,"ncols":0,"row_indices":[0,1,0,0,2],"col_indices":[0,2,1,3,3],"values":[2,3,7,3,4]}"#).is_err());
|
||||
assert!(serde_json::from_str::<CooMatrix<i32>>(r#"{"nrows":-3,"ncols":5,"row_indices":[0,1,0,0,2],"col_indices":[0,2,1,3,3],"values":[2,3,7,3,4]}"#).is_err());
|
||||
assert!(serde_json::from_str::<CooMatrix<i32>>(r#"{"nrows":3,"ncols":5,"row_indices":[0,1,0,0,2],"col_indices":[0,2,1,3,3],"values":[2,3,7,3]}"#).is_err());
|
||||
|
@ -101,6 +141,7 @@ fn csc_roundtrip() {
|
|||
#[test]
|
||||
fn csc_deserialize_invalid() {
|
||||
// Valid matrix: {"nrows":6,"ncols":3,"col_offsets":[0,2,2,5],"row_indices":[0,5,1,2,3],"values":[0,1,2,3,4]}
|
||||
assert!(serde_json::from_str::<CscMatrix<i32>>(r#"{"nrows":6,"ncols":3,"col_offsets":[0,2,2,5],"row_indices":[0,5,1,2,3],"values":[0,1,2,3,4]}"#).is_ok());
|
||||
assert!(serde_json::from_str::<CscMatrix<i32>>(r#"{"nrows":0,"ncols":0,"col_offsets":[0,2,2,5],"row_indices":[0,5,1,2,3],"values":[0,1,2,3,4]}"#).is_err());
|
||||
assert!(serde_json::from_str::<CscMatrix<i32>>(r#"{"nrows":-6,"ncols":3,"col_offsets":[0,2,2,5],"row_indices":[0,5,1,2,3],"values":[0,1,2,3,4]}"#).is_err());
|
||||
assert!(serde_json::from_str::<CscMatrix<i32>>(r#"{"nrows":6,"ncols":3,"col_offsets":[0,2,2,5],"row_indices":[0,5,1,2,3],"values":[0,1,2,3]}"#).is_err());
|
||||
|
@ -140,6 +181,7 @@ fn csr_roundtrip() {
|
|||
#[test]
|
||||
fn csr_deserialize_invalid() {
|
||||
// Valid matrix: {"nrows":3,"ncols":6,"row_offsets":[0,2,2,5],"col_indices":[0,5,1,2,3],"values":[0,1,2,3,4]}
|
||||
assert!(serde_json::from_str::<CsrMatrix<i32>>(r#"{"nrows":3,"ncols":6,"row_offsets":[0,2,2,5],"col_indices":[0,5,1,2,3],"values":[0,1,2,3,4]}"#).is_ok());
|
||||
assert!(serde_json::from_str::<CsrMatrix<i32>>(r#"{"nrows":0,"ncols":0,"row_offsets":[0,2,2,5],"col_indices":[0,5,1,2,3],"values":[0,1,2,3,4]}"#).is_err());
|
||||
assert!(serde_json::from_str::<CsrMatrix<i32>>(r#"{"nrows":-3,"ncols":6,"row_offsets":[0,2,2,5],"col_indices":[0,5,1,2,3],"values":[0,1,2,3,4]}"#).is_err());
|
||||
assert!(serde_json::from_str::<CsrMatrix<i32>>(r#"{"nrows":3,"ncols":6,"row_offsets":[0,2,2,5],"col_indices":[0,5,1,2,3],"values":[0,1,2,3]}"#).is_err());
|
||||
|
|
Loading…
Reference in New Issue