From f9aca24b1567115e23d0968d227f0b2e7f1f97ac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabian=20L=C3=B6schner?= Date: Sun, 7 Nov 2021 21:32:11 +0100 Subject: [PATCH] Implement Serialize and Deserialize for CsrMatrix --- nalgebra-sparse/Cargo.toml | 5 ++- nalgebra-sparse/src/csr.rs | 64 ++++++++++++++++++++++++++++++++++ nalgebra-sparse/tests/serde.rs | 62 ++++++++++++++++++++++++++++++++ 3 files changed, 130 insertions(+), 1 deletion(-) create mode 100644 nalgebra-sparse/tests/serde.rs diff --git a/nalgebra-sparse/Cargo.toml b/nalgebra-sparse/Cargo.toml index 6f7a7b4a..eec7326d 100644 --- a/nalgebra-sparse/Cargo.toml +++ b/nalgebra-sparse/Cargo.toml @@ -15,6 +15,7 @@ license = "Apache-2.0" [features] proptest-support = ["proptest", "nalgebra/proptest-support"] compare = [ "matrixcompare-core" ] +serde-serialize = [ "serde/std" ] # Enable matrix market I/O io = [ "pest", "pest_derive" ] @@ -29,12 +30,14 @@ proptest = { version = "1.0", optional = true } matrixcompare-core = { version = "0.1.0", optional = true } pest = { version = "2", optional = true } pest_derive = { version = "2", optional = true } +serde = { version = "1.0", default-features = false, features = [ "derive" ], optional = true } [dev-dependencies] itertools = "0.10" matrixcompare = { version = "0.3.0", features = [ "proptest-support" ] } nalgebra = { version="0.30", path = "../", features = ["compare"] } +serde_json = "1.0" [package.metadata.docs.rs] # Enable certain features when building docs for docs.rs -features = [ "proptest-support", "compare" ] \ No newline at end of file +features = [ "proptest-support", "compare" ] diff --git a/nalgebra-sparse/src/csr.rs b/nalgebra-sparse/src/csr.rs index 4324d18d..b98717b8 100644 --- a/nalgebra-sparse/src/csr.rs +++ b/nalgebra-sparse/src/csr.rs @@ -10,6 +10,9 @@ use crate::{SparseEntry, SparseEntryMut, SparseFormatError, SparseFormatErrorKin use nalgebra::Scalar; use num_traits::One; +#[cfg(feature = "serde-serialize")] +use serde::{de, Deserialize, Deserializer, Serialize, Serializer}; + use std::iter::FromIterator; use std::slice::{Iter, IterMut}; @@ -596,6 +599,67 @@ impl CsrMatrix { } } +#[cfg_attr(feature = "serde-serialize", derive(Serialize))] +struct CsrMatrixSerializationHelper<'a, T> { + nrows: usize, + ncols: usize, + row_offsets: &'a [usize], + col_indices: &'a [usize], + values: &'a [T], +} + +#[cfg(feature = "serde-serialize")] +impl Serialize for CsrMatrix +where + T: Serialize, +{ + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + CsrMatrixSerializationHelper { + nrows: self.nrows(), + ncols: self.ncols(), + row_offsets: self.row_offsets(), + col_indices: self.col_indices(), + values: self.values(), + } + .serialize(serializer) + } +} + +#[cfg_attr(feature = "serde-serialize", derive(Deserialize))] +struct CsrMatrixDeserializationHelper { + nrows: usize, + ncols: usize, + row_offsets: Vec, + col_indices: Vec, + values: Vec, +} + +#[cfg(feature = "serde-serialize")] +impl<'de, T> Deserialize<'de> for CsrMatrix +where + T: for<'de2> Deserialize<'de2>, +{ + fn deserialize(deserializer: D) -> Result, D::Error> + where + D: Deserializer<'de>, + { + let CsrMatrixDeserializationHelper { + nrows, + ncols, + row_offsets, + col_indices, + values, + } = CsrMatrixDeserializationHelper::deserialize(deserializer)?; + CsrMatrix::try_from_csr_data(nrows, ncols, row_offsets, col_indices, values) + .map(|m| m.into()) + // TODO: More specific error + .map_err(|_e| de::Error::invalid_value(de::Unexpected::Other("invalid CSR matrix"), &"a valid CSR matrix")) + } +} + /// Convert pattern format errors into more meaningful CSR-specific errors. /// /// This ensures that the terminology is consistent: we are talking about rows and columns, diff --git a/nalgebra-sparse/tests/serde.rs b/nalgebra-sparse/tests/serde.rs new file mode 100644 index 00000000..ecee76d1 --- /dev/null +++ b/nalgebra-sparse/tests/serde.rs @@ -0,0 +1,62 @@ +#![cfg(feature = "serde-serialize")] +//! Serialization tests +#[cfg(any(not(feature = "proptest-support"), not(feature = "compare")))] +compile_error!("Tests must be run with features `proptest-support` and `compare`"); + +#[macro_use] +pub mod common; + +use nalgebra_sparse::csr::CsrMatrix; + +use proptest::prelude::*; +use serde::{Deserialize, Serialize}; + +use crate::common::csr_strategy; + +fn json_roundtrip Deserialize<'a>>(csr: &CsrMatrix) -> CsrMatrix { + let serialized = serde_json::to_string(csr).unwrap(); + let deserialized: CsrMatrix = serde_json::from_str(&serialized).unwrap(); + deserialized +} + +#[test] +fn csr_roundtrip() { + { + // A CSR matrix with zero explicitly stored entries + let offsets = vec![0, 0, 0, 0]; + let indices = vec![]; + let values = Vec::::new(); + let matrix = CsrMatrix::try_from_csr_data(3, 2, offsets, indices, values).unwrap(); + + assert_eq!(json_roundtrip(&matrix), matrix); + } + + { + // An arbitrary CSR matrix + let offsets = vec![0, 2, 2, 5]; + let indices = vec![0, 5, 1, 2, 3]; + let values = vec![0, 1, 2, 3, 4]; + let matrix = + CsrMatrix::try_from_csr_data(3, 6, offsets.clone(), indices.clone(), values.clone()) + .unwrap(); + + assert_eq!(json_roundtrip(&matrix), matrix); + } +} + +#[test] +fn invalid_csr_deserialize() { + // Valid matrix: {"nrows":3,"ncols":6,"row_offsets":[0,2,2,5],"col_indices":[0,5,1,2,3],"values":[0,1,2,3,4]} + assert!(serde_json::from_str::>(r#"{"nrows":3,"ncols":6,"row_offsets":[0,2,2,5],"col_indices":[0,5,1,2,3],"values":[0,1,2,3]}"#).is_err()); + assert!(serde_json::from_str::>(r#"{"nrows":3,"ncols":6,"row_offsets":[0,2,2,5],"col_indices":[0,5,1,8,3],"values":[0,1,2,3,4]}"#).is_err()); + assert!(serde_json::from_str::>(r#"{"nrows":3,"ncols":6,"row_offsets":[0,2,2,5],"col_indices":[0,5,1,2,3,1,1],"values":[0,1,2,3,4]}"#).is_err()); + // The following actually panics ('range end index 10 out of range for slice of length 5', nalgebra-sparse\src\pattern.rs:156:38) + //assert!(serde_json::from_str::>(r#"{"nrows":3,"ncols":6,"row_offsets":[0,10,2,5],"col_indices":[0,5,1,2,3],"values":[0,1,2,3,4]}"#).is_err()); +} + +proptest! { + #[test] + fn csr_roundtrip_proptest(csr in csr_strategy()) { + prop_assert_eq!(json_roundtrip(&csr), csr); + } +}