From bfaf29393c546e58c0effee163fb38b0bdc0aca0 Mon Sep 17 00:00:00 2001 From: Fabian Loeschner Date: Tue, 9 Nov 2021 10:59:24 +0100 Subject: [PATCH] Implement Serialize, Deserialize for SparsityPattern --- nalgebra-sparse/src/pattern.rs | 51 ++++++++++++++++++++++++++++++++++ nalgebra-sparse/tests/serde.rs | 42 ++++++++++++++++++++++++++++ 2 files changed, 93 insertions(+) diff --git a/nalgebra-sparse/src/pattern.rs b/nalgebra-sparse/src/pattern.rs index 85f6bc1a..4243543d 100644 --- a/nalgebra-sparse/src/pattern.rs +++ b/nalgebra-sparse/src/pattern.rs @@ -4,6 +4,9 @@ use crate::SparseFormatError; use std::error::Error; use std::fmt; +#[cfg(feature = "serde-serialize")] +use serde::{de, Deserialize, Deserializer, Serialize, Serializer}; + /// A representation of the sparsity pattern of a CSR or CSC matrix. /// /// CSR and CSC matrices store matrices in a very similar fashion. In fact, in a certain sense, @@ -285,6 +288,54 @@ pub enum SparsityPatternFormatError { NonmonotonicMinorIndices, } +#[cfg(feature = "serde-serialize")] +#[derive(Serialize)] +struct SparsityPatternSerializationData<'a> { + major_dim: usize, + minor_dim: usize, + major_offsets: &'a [usize], + minor_indices: &'a [usize], +} + +#[cfg(feature = "serde-serialize")] +impl Serialize for SparsityPattern { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + SparsityPatternSerializationData { + major_dim: self.major_dim(), + minor_dim: self.minor_dim(), + major_offsets: self.major_offsets(), + minor_indices: self.minor_indices(), + } + .serialize(serializer) + } +} + +#[cfg(feature = "serde-serialize")] +#[derive(Deserialize)] +struct SparsityPatternDeserializationData { + major_dim: usize, + minor_dim: usize, + major_offsets: Vec, + minor_indices: Vec, +} + +#[cfg(feature = "serde-serialize")] +impl<'de> Deserialize<'de> for SparsityPattern { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let de = SparsityPatternDeserializationData::deserialize(deserializer)?; + SparsityPattern::try_from_offsets_and_indices(de.major_dim, de.minor_dim, de.major_offsets, de.minor_indices) + .map(|m| m.into()) + // TODO: More specific error + .map_err(|_e| de::Error::invalid_value(de::Unexpected::Other("invalid sparsity pattern"), &"a valid sparsity pattern")) + } +} + impl From for SparseFormatError { fn from(err: SparsityPatternFormatError) -> Self { use crate::SparseFormatErrorKind; diff --git a/nalgebra-sparse/tests/serde.rs b/nalgebra-sparse/tests/serde.rs index 5afe93c5..1ce1953f 100644 --- a/nalgebra-sparse/tests/serde.rs +++ b/nalgebra-sparse/tests/serde.rs @@ -9,6 +9,7 @@ pub mod common; use nalgebra_sparse::coo::CooMatrix; use nalgebra_sparse::csc::CscMatrix; use nalgebra_sparse::csr::CsrMatrix; +use nalgebra_sparse::pattern::SparsityPattern; use proptest::prelude::*; use serde::{Deserialize, Serialize}; @@ -21,6 +22,44 @@ fn json_roundtrip Deserialize<'a>>(csr: &T) -> T { deserialized } +#[test] +fn pattern_roundtrip() { + { + // A pattern with zero explicitly stored entries + let pattern = + SparsityPattern::try_from_offsets_and_indices(3, 2, vec![0, 0, 0, 0], Vec::new()) + .unwrap(); + + assert_eq!(json_roundtrip(&pattern), pattern); + } + + { + // Arbitrary pattern + let offsets = vec![0, 2, 2, 5]; + let indices = vec![0, 5, 1, 2, 3]; + let pattern = + SparsityPattern::try_from_offsets_and_indices(3, 6, offsets.clone(), indices.clone()) + .unwrap(); + + assert_eq!(json_roundtrip(&pattern), pattern); + } +} + +#[test] +#[rustfmt::skip] +fn pattern_deserialize_invalid() { + assert!(serde_json::from_str::(r#"{"major_dim":3,"minor_dim":6,"major_offsets":[0,2,2,5],"minor_indices":[0,5,1,2,3]}"#).is_ok()); + assert!(serde_json::from_str::(r#"{"major_dim":0,"minor_dim":0,"major_offsets":[],"minor_indices":[]}"#).is_err()); + assert!(serde_json::from_str::(r#"{"major_dim":3,"minor_dim":6,"major_offsets":[0, 3, 5],"minor_indices":[0, 1, 2, 3, 5]}"#).is_err()); + assert!(serde_json::from_str::(r#"{"major_dim":3,"minor_dim":6,"major_offsets":[1, 2, 2, 5],"minor_indices":[0, 5, 1, 2, 3]}"#).is_err()); + assert!(serde_json::from_str::(r#"{"major_dim":3,"minor_dim":6,"major_offsets":[0, 2, 2, 4],"minor_indices":[0, 5, 1, 2, 3]}"#).is_err()); + assert!(serde_json::from_str::(r#"{"major_dim":3,"minor_dim":6,"major_offsets":[0, 2, 2],"minor_indices":[0, 5, 1, 2, 3]}"#).is_err()); + assert!(serde_json::from_str::(r#"{"major_dim":3,"minor_dim":6,"major_offsets":[0, 3, 2, 5],"minor_indices":[0, 1, 2, 3, 4]}"#).is_err()); + assert!(serde_json::from_str::(r#"{"major_dim":3,"minor_dim":6,"major_offsets":[0, 2, 2, 5],"minor_indices":[0, 2, 3, 1, 4]}"#).is_err()); + assert!(serde_json::from_str::(r#"{"major_dim":3,"minor_dim":6,"major_offsets":[0, 2, 2, 5],"minor_indices":[0, 6, 1, 2, 3]}"#).is_err()); + assert!(serde_json::from_str::(r#"{"major_dim":3,"minor_dim":6,"major_offsets":[0, 2, 2, 5],"minor_indices":[0, 5, 2, 2, 3]}"#).is_err()); +} + #[test] fn coo_roundtrip() { { @@ -46,6 +85,7 @@ fn coo_roundtrip() { #[test] fn coo_deserialize_invalid() { // Valid matrix: {"nrows":3,"ncols":5,"row_indices":[0,1,0,0,2],"col_indices":[0,2,1,3,3],"values":[2,3,7,3,1]} + assert!(serde_json::from_str::>(r#"{"nrows":3,"ncols":5,"row_indices":[0,1,0,0,2],"col_indices":[0,2,1,3,3],"values":[2,3,7,3,1]}"#).is_ok()); assert!(serde_json::from_str::>(r#"{"nrows":0,"ncols":0,"row_indices":[0,1,0,0,2],"col_indices":[0,2,1,3,3],"values":[2,3,7,3,4]}"#).is_err()); assert!(serde_json::from_str::>(r#"{"nrows":-3,"ncols":5,"row_indices":[0,1,0,0,2],"col_indices":[0,2,1,3,3],"values":[2,3,7,3,4]}"#).is_err()); assert!(serde_json::from_str::>(r#"{"nrows":3,"ncols":5,"row_indices":[0,1,0,0,2],"col_indices":[0,2,1,3,3],"values":[2,3,7,3]}"#).is_err()); @@ -101,6 +141,7 @@ fn csc_roundtrip() { #[test] fn csc_deserialize_invalid() { // Valid matrix: {"nrows":6,"ncols":3,"col_offsets":[0,2,2,5],"row_indices":[0,5,1,2,3],"values":[0,1,2,3,4]} + assert!(serde_json::from_str::>(r#"{"nrows":6,"ncols":3,"col_offsets":[0,2,2,5],"row_indices":[0,5,1,2,3],"values":[0,1,2,3,4]}"#).is_ok()); assert!(serde_json::from_str::>(r#"{"nrows":0,"ncols":0,"col_offsets":[0,2,2,5],"row_indices":[0,5,1,2,3],"values":[0,1,2,3,4]}"#).is_err()); assert!(serde_json::from_str::>(r#"{"nrows":-6,"ncols":3,"col_offsets":[0,2,2,5],"row_indices":[0,5,1,2,3],"values":[0,1,2,3,4]}"#).is_err()); assert!(serde_json::from_str::>(r#"{"nrows":6,"ncols":3,"col_offsets":[0,2,2,5],"row_indices":[0,5,1,2,3],"values":[0,1,2,3]}"#).is_err()); @@ -140,6 +181,7 @@ fn csr_roundtrip() { #[test] fn csr_deserialize_invalid() { // Valid matrix: {"nrows":3,"ncols":6,"row_offsets":[0,2,2,5],"col_indices":[0,5,1,2,3],"values":[0,1,2,3,4]} + assert!(serde_json::from_str::>(r#"{"nrows":3,"ncols":6,"row_offsets":[0,2,2,5],"col_indices":[0,5,1,2,3],"values":[0,1,2,3,4]}"#).is_ok()); assert!(serde_json::from_str::>(r#"{"nrows":0,"ncols":0,"row_offsets":[0,2,2,5],"col_indices":[0,5,1,2,3],"values":[0,1,2,3,4]}"#).is_err()); assert!(serde_json::from_str::>(r#"{"nrows":-3,"ncols":6,"row_offsets":[0,2,2,5],"col_indices":[0,5,1,2,3],"values":[0,1,2,3,4]}"#).is_err()); assert!(serde_json::from_str::>(r#"{"nrows":3,"ncols":6,"row_offsets":[0,2,2,5],"col_indices":[0,5,1,2,3],"values":[0,1,2,3]}"#).is_err());