From 40d8a904a3e5c1f486adfbc40fb2562254985c64 Mon Sep 17 00:00:00 2001 From: Fabian Loeschner Date: Tue, 9 Nov 2021 10:30:02 +0100 Subject: [PATCH] Implement Serialize, Deserialize for Csc, Coo; move helper out of impls --- nalgebra-sparse/src/coo.rs | 60 ++++++++++++++++++ nalgebra-sparse/src/csc.rs | 59 ++++++++++++++++++ nalgebra-sparse/src/csr.rs | 40 ++++++------ nalgebra-sparse/tests/serde.rs | 110 +++++++++++++++++++++++++++++++-- 4 files changed, 245 insertions(+), 24 deletions(-) diff --git a/nalgebra-sparse/src/coo.rs b/nalgebra-sparse/src/coo.rs index 34e5ceec..35a14083 100644 --- a/nalgebra-sparse/src/coo.rs +++ b/nalgebra-sparse/src/coo.rs @@ -2,6 +2,9 @@ use crate::SparseFormatError; +#[cfg(feature = "serde-serialize")] +use serde::{de, Deserialize, Deserializer, Serialize, Serializer}; + /// A COO representation of a sparse matrix. /// /// A COO matrix stores entries in coordinate-form, that is triplets `(i, j, v)`, where `i` and `j` @@ -273,3 +276,60 @@ impl CooMatrix { (self.row_indices, self.col_indices, self.values) } } + +#[cfg(feature = "serde-serialize")] +#[derive(Serialize)] +struct CooMatrixSerializationData<'a, T> { + nrows: usize, + ncols: usize, + row_indices: &'a [usize], + col_indices: &'a [usize], + values: &'a [T], +} + +#[cfg(feature = "serde-serialize")] +impl Serialize for CooMatrix +where + T: Serialize, +{ + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + CooMatrixSerializationData { + nrows: self.nrows(), + ncols: self.ncols(), + row_indices: self.row_indices(), + col_indices: self.col_indices(), + values: self.values(), + } + .serialize(serializer) + } +} + +#[cfg(feature = "serde-serialize")] +#[derive(Deserialize)] +struct CooMatrixDeserializationData { + nrows: usize, + ncols: usize, + row_indices: Vec, + col_indices: Vec, + values: Vec, +} + +#[cfg(feature = "serde-serialize")] +impl<'de, T> Deserialize<'de> for CooMatrix +where + T: for<'de2> Deserialize<'de2>, +{ + fn deserialize(deserializer: D) -> Result, D::Error> + where + D: Deserializer<'de>, + { + let de = CooMatrixDeserializationData::deserialize(deserializer)?; + CooMatrix::try_from_triplets(de.nrows, de.ncols, de.row_indices, de.col_indices, de.values) + .map(|m| m.into()) + // TODO: More specific error + .map_err(|_e| de::Error::invalid_value(de::Unexpected::Other("invalid COO matrix"), &"a valid COO matrix")) + } +} diff --git a/nalgebra-sparse/src/csc.rs b/nalgebra-sparse/src/csc.rs index 607cc0cf..cb7cb79b 100644 --- a/nalgebra-sparse/src/csc.rs +++ b/nalgebra-sparse/src/csc.rs @@ -10,6 +10,8 @@ use crate::{SparseEntry, SparseEntryMut, SparseFormatError, SparseFormatErrorKin use nalgebra::Scalar; use num_traits::One; +#[cfg(feature = "serde-serialize")] +use serde::{de, Deserialize, Deserializer, Serialize, Serializer}; use std::slice::{Iter, IterMut}; /// A CSC representation of a sparse matrix. @@ -524,6 +526,63 @@ impl CscMatrix { } } +#[cfg(feature = "serde-serialize")] +#[derive(Serialize)] +struct CscMatrixSerializationData<'a, T> { + nrows: usize, + ncols: usize, + col_offsets: &'a [usize], + row_indices: &'a [usize], + values: &'a [T], +} + +#[cfg(feature = "serde-serialize")] +impl Serialize for CscMatrix +where + T: Serialize, +{ + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + CscMatrixSerializationData { + nrows: self.nrows(), + ncols: self.ncols(), + col_offsets: self.col_offsets(), + row_indices: self.row_indices(), + values: self.values(), + } + .serialize(serializer) + } +} + +#[cfg(feature = "serde-serialize")] +#[derive(Deserialize)] +struct CscMatrixDeserializationData { + nrows: usize, + ncols: usize, + col_offsets: Vec, + row_indices: Vec, + values: Vec, +} + +#[cfg(feature = "serde-serialize")] +impl<'de, T> Deserialize<'de> for CscMatrix +where + T: for<'de2> Deserialize<'de2>, +{ + fn deserialize(deserializer: D) -> Result, D::Error> + where + D: Deserializer<'de>, + { + let de = CscMatrixDeserializationData::deserialize(deserializer)?; + CscMatrix::try_from_csc_data(de.nrows, de.ncols, de.col_offsets, de.row_indices, de.values) + .map(|m| m.into()) + // TODO: More specific error + .map_err(|_e| de::Error::invalid_value(de::Unexpected::Other("invalid CSC matrix"), &"a valid CSC matrix")) + } +} + /// Convert pattern format errors into more meaningful CSC-specific errors. /// /// This ensures that the terminology is consistent: we are talking about rows and columns, diff --git a/nalgebra-sparse/src/csr.rs b/nalgebra-sparse/src/csr.rs index bc2a3df0..b36fbb2f 100644 --- a/nalgebra-sparse/src/csr.rs +++ b/nalgebra-sparse/src/csr.rs @@ -9,10 +9,8 @@ use crate::{SparseEntry, SparseEntryMut, SparseFormatError, SparseFormatErrorKin use nalgebra::Scalar; use num_traits::One; - #[cfg(feature = "serde-serialize")] use serde::{de, Deserialize, Deserializer, Serialize, Serializer}; - use std::iter::FromIterator; use std::slice::{Iter, IterMut}; @@ -599,6 +597,16 @@ impl CsrMatrix { } } +#[cfg(feature = "serde-serialize")] +#[derive(Serialize)] +struct CsrMatrixSerializationData<'a, T> { + nrows: usize, + ncols: usize, + row_offsets: &'a [usize], + col_indices: &'a [usize], + values: &'a [T], +} + #[cfg(feature = "serde-serialize")] impl Serialize for CsrMatrix where @@ -608,15 +616,6 @@ where where S: Serializer, { - #[derive(Serialize)] - struct CsrMatrixSerializationData<'a, T> { - nrows: usize, - ncols: usize, - row_offsets: &'a [usize], - col_indices: &'a [usize], - values: &'a [T], - } - CsrMatrixSerializationData { nrows: self.nrows(), ncols: self.ncols(), @@ -628,6 +627,16 @@ where } } +#[cfg(feature = "serde-serialize")] +#[derive(Deserialize)] +struct CsrMatrixDeserializationData { + nrows: usize, + ncols: usize, + row_offsets: Vec, + col_indices: Vec, + values: Vec, +} + #[cfg(feature = "serde-serialize")] impl<'de, T> Deserialize<'de> for CsrMatrix where @@ -637,15 +646,6 @@ where where D: Deserializer<'de>, { - #[derive(Deserialize)] - struct CsrMatrixDeserializationData { - nrows: usize, - ncols: usize, - row_offsets: Vec, - col_indices: Vec, - values: Vec, - } - let de = CsrMatrixDeserializationData::deserialize(deserializer)?; CsrMatrix::try_from_csr_data(de.nrows, de.ncols, de.row_offsets, de.col_indices, de.values) .map(|m| m.into()) diff --git a/nalgebra-sparse/tests/serde.rs b/nalgebra-sparse/tests/serde.rs index ecee76d1..5afe93c5 100644 --- a/nalgebra-sparse/tests/serde.rs +++ b/nalgebra-sparse/tests/serde.rs @@ -6,19 +6,112 @@ compile_error!("Tests must be run with features `proptest-support` and `compare` #[macro_use] pub mod common; +use nalgebra_sparse::coo::CooMatrix; +use nalgebra_sparse::csc::CscMatrix; use nalgebra_sparse::csr::CsrMatrix; use proptest::prelude::*; use serde::{Deserialize, Serialize}; -use crate::common::csr_strategy; +use crate::common::{csc_strategy, csr_strategy}; -fn json_roundtrip Deserialize<'a>>(csr: &CsrMatrix) -> CsrMatrix { +fn json_roundtrip Deserialize<'a>>(csr: &T) -> T { let serialized = serde_json::to_string(csr).unwrap(); - let deserialized: CsrMatrix = serde_json::from_str(&serialized).unwrap(); + let deserialized: T = serde_json::from_str(&serialized).unwrap(); deserialized } +#[test] +fn coo_roundtrip() { + { + // A COO matrix without entries + let matrix = + CooMatrix::::try_from_triplets(3, 2, Vec::new(), Vec::new(), Vec::new()).unwrap(); + + assert_eq!(json_roundtrip(&matrix), matrix); + } + + { + // Arbitrary COO matrix, no duplicates + let i = vec![0, 1, 0, 0, 2]; + let j = vec![0, 2, 1, 3, 3]; + let v = vec![2, 3, 7, 3, 1]; + let matrix = + CooMatrix::::try_from_triplets(3, 5, i.clone(), j.clone(), v.clone()).unwrap(); + + assert_eq!(json_roundtrip(&matrix), matrix); + } +} + +#[test] +fn coo_deserialize_invalid() { + // Valid matrix: {"nrows":3,"ncols":5,"row_indices":[0,1,0,0,2],"col_indices":[0,2,1,3,3],"values":[2,3,7,3,1]} + assert!(serde_json::from_str::>(r#"{"nrows":0,"ncols":0,"row_indices":[0,1,0,0,2],"col_indices":[0,2,1,3,3],"values":[2,3,7,3,4]}"#).is_err()); + assert!(serde_json::from_str::>(r#"{"nrows":-3,"ncols":5,"row_indices":[0,1,0,0,2],"col_indices":[0,2,1,3,3],"values":[2,3,7,3,4]}"#).is_err()); + assert!(serde_json::from_str::>(r#"{"nrows":3,"ncols":5,"row_indices":[0,1,0,0,2],"col_indices":[0,2,1,3,3],"values":[2,3,7,3]}"#).is_err()); + assert!(serde_json::from_str::>(r#"{"nrows":3,"ncols":5,"row_indices":[0,1,0,0,2],"col_indices":[0,2,1,3,3],"values":[2,3,7,3,4,5]}"#).is_err()); + assert!(serde_json::from_str::>(r#"{"nrows":3,"ncols":5,"row_indices":[0,1,0,0,2],"col_indices":[0,2,1,8,3],"values":[2,3,7,3,4]}"#).is_err()); + assert!(serde_json::from_str::>(r#"{"nrows":3,"ncols":5,"row_indices":[0,1,0,0],"col_indices":[0,2,1,8,3],"values":[2,3,7,3,4]}"#).is_err()); + assert!(serde_json::from_str::>(r#"{"nrows":3,"ncols":5,"row_indices":[0,10,0,0,2],"col_indices":[0,2,1,3,3],"values":[2,3,7,3,4]}"#).is_err()); + assert!(serde_json::from_str::>(r#"{"nrows":3,"ncols":5,"row_indices":[0,1,0,0,2],"col_indices":[0,2,1,30,3],"values":[2,3,7,3,4]}"#).is_err()); +} + +#[test] +fn coo_deserialize_duplicates() { + assert_eq!( + serde_json::from_str::>( + r#"{"nrows":3,"ncols":5,"row_indices":[0,1,0,0,2,0,1],"col_indices":[0,2,1,3,3,0,2],"values":[2,3,7,3,1,5,6]}"# + ).unwrap(), + CooMatrix::::try_from_triplets( + 3, + 5, + vec![0, 1, 0, 0, 2, 0, 1], + vec![0, 2, 1, 3, 3, 0, 2], + vec![2, 3, 7, 3, 1, 5, 6] + ) + .unwrap() + ); +} + +#[test] +fn csc_roundtrip() { + { + // A CSC matrix with zero explicitly stored entries + let offsets = vec![0, 0, 0, 0]; + let indices = vec![]; + let values = Vec::::new(); + let matrix = CscMatrix::try_from_csc_data(2, 3, offsets, indices, values).unwrap(); + + assert_eq!(json_roundtrip(&matrix), matrix); + } + + { + // An arbitrary CSC matrix + let offsets = vec![0, 2, 2, 5]; + let indices = vec![0, 5, 1, 2, 3]; + let values = vec![0, 1, 2, 3, 4]; + let matrix = + CscMatrix::try_from_csc_data(6, 3, offsets.clone(), indices.clone(), values.clone()) + .unwrap(); + + assert_eq!(json_roundtrip(&matrix), matrix); + } +} + +#[test] +fn csc_deserialize_invalid() { + // Valid matrix: {"nrows":6,"ncols":3,"col_offsets":[0,2,2,5],"row_indices":[0,5,1,2,3],"values":[0,1,2,3,4]} + assert!(serde_json::from_str::>(r#"{"nrows":0,"ncols":0,"col_offsets":[0,2,2,5],"row_indices":[0,5,1,2,3],"values":[0,1,2,3,4]}"#).is_err()); + assert!(serde_json::from_str::>(r#"{"nrows":-6,"ncols":3,"col_offsets":[0,2,2,5],"row_indices":[0,5,1,2,3],"values":[0,1,2,3,4]}"#).is_err()); + assert!(serde_json::from_str::>(r#"{"nrows":6,"ncols":3,"col_offsets":[0,2,2,5],"row_indices":[0,5,1,2,3],"values":[0,1,2,3]}"#).is_err()); + assert!(serde_json::from_str::>(r#"{"nrows":6,"ncols":3,"col_offsets":[0,2,2,5],"row_indices":[0,5,1,2,3],"values":[0,1,2,3,4,5]}"#).is_err()); + assert!(serde_json::from_str::>(r#"{"nrows":6,"ncols":3,"col_offsets":[0,2,2,5],"row_indices":[0,5,1,8,3],"values":[0,1,2,3,4]}"#).is_err()); + assert!(serde_json::from_str::>(r#"{"nrows":6,"ncols":3,"col_offsets":[0,2,2,5],"row_indices":[0,5,1,2,3,1,1],"values":[0,1,2,3,4]}"#).is_err()); + // The following actually panics ('range end index 10 out of range for slice of length 5', nalgebra-sparse\src\pattern.rs:156:38) + //assert!(serde_json::from_str::>(r#"{"nrows":6,"ncols":3,"col_offsets":[0,10,2,5],"row_indices":[0,5,1,2,3],"values":[0,1,2,3,4]}"#).is_err()); + assert!(serde_json::from_str::>(r#"{"nrows":3,"ncols":6,"row_offsets":[0,2,2,5],"col_indices":[0,5,1,2,3],"values":[0,1,2,3,4]}"#).is_err()); +} + #[test] fn csr_roundtrip() { { @@ -45,16 +138,25 @@ fn csr_roundtrip() { } #[test] -fn invalid_csr_deserialize() { +fn csr_deserialize_invalid() { // Valid matrix: {"nrows":3,"ncols":6,"row_offsets":[0,2,2,5],"col_indices":[0,5,1,2,3],"values":[0,1,2,3,4]} + assert!(serde_json::from_str::>(r#"{"nrows":0,"ncols":0,"row_offsets":[0,2,2,5],"col_indices":[0,5,1,2,3],"values":[0,1,2,3,4]}"#).is_err()); + assert!(serde_json::from_str::>(r#"{"nrows":-3,"ncols":6,"row_offsets":[0,2,2,5],"col_indices":[0,5,1,2,3],"values":[0,1,2,3,4]}"#).is_err()); assert!(serde_json::from_str::>(r#"{"nrows":3,"ncols":6,"row_offsets":[0,2,2,5],"col_indices":[0,5,1,2,3],"values":[0,1,2,3]}"#).is_err()); + assert!(serde_json::from_str::>(r#"{"nrows":3,"ncols":6,"row_offsets":[0,2,2,5],"col_indices":[0,5,1,2,3],"values":[0,1,2,3,4,5]}"#).is_err()); assert!(serde_json::from_str::>(r#"{"nrows":3,"ncols":6,"row_offsets":[0,2,2,5],"col_indices":[0,5,1,8,3],"values":[0,1,2,3,4]}"#).is_err()); assert!(serde_json::from_str::>(r#"{"nrows":3,"ncols":6,"row_offsets":[0,2,2,5],"col_indices":[0,5,1,2,3,1,1],"values":[0,1,2,3,4]}"#).is_err()); // The following actually panics ('range end index 10 out of range for slice of length 5', nalgebra-sparse\src\pattern.rs:156:38) //assert!(serde_json::from_str::>(r#"{"nrows":3,"ncols":6,"row_offsets":[0,10,2,5],"col_indices":[0,5,1,2,3],"values":[0,1,2,3,4]}"#).is_err()); + assert!(serde_json::from_str::>(r#"{"nrows":6,"ncols":3,"col_offsets":[0,2,2,5],"row_indices":[0,5,1,2,3],"values":[0,1,2,3,4]}"#).is_err()); } proptest! { + #[test] + fn csc_roundtrip_proptest(csc in csc_strategy()) { + prop_assert_eq!(json_roundtrip(&csc), csc); + } + #[test] fn csr_roundtrip_proptest(csr in csr_strategy()) { prop_assert_eq!(json_roundtrip(&csr), csr);