Implement Serialize and Deserialize for CsrMatrix
This commit is contained in:
parent
6cc633474d
commit
f9aca24b15
|
@ -15,6 +15,7 @@ license = "Apache-2.0"
|
||||||
[features]
|
[features]
|
||||||
proptest-support = ["proptest", "nalgebra/proptest-support"]
|
proptest-support = ["proptest", "nalgebra/proptest-support"]
|
||||||
compare = [ "matrixcompare-core" ]
|
compare = [ "matrixcompare-core" ]
|
||||||
|
serde-serialize = [ "serde/std" ]
|
||||||
|
|
||||||
# Enable matrix market I/O
|
# Enable matrix market I/O
|
||||||
io = [ "pest", "pest_derive" ]
|
io = [ "pest", "pest_derive" ]
|
||||||
|
@ -29,11 +30,13 @@ proptest = { version = "1.0", optional = true }
|
||||||
matrixcompare-core = { version = "0.1.0", optional = true }
|
matrixcompare-core = { version = "0.1.0", optional = true }
|
||||||
pest = { version = "2", optional = true }
|
pest = { version = "2", optional = true }
|
||||||
pest_derive = { version = "2", optional = true }
|
pest_derive = { version = "2", optional = true }
|
||||||
|
serde = { version = "1.0", default-features = false, features = [ "derive" ], optional = true }
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
itertools = "0.10"
|
itertools = "0.10"
|
||||||
matrixcompare = { version = "0.3.0", features = [ "proptest-support" ] }
|
matrixcompare = { version = "0.3.0", features = [ "proptest-support" ] }
|
||||||
nalgebra = { version="0.30", path = "../", features = ["compare"] }
|
nalgebra = { version="0.30", path = "../", features = ["compare"] }
|
||||||
|
serde_json = "1.0"
|
||||||
|
|
||||||
[package.metadata.docs.rs]
|
[package.metadata.docs.rs]
|
||||||
# Enable certain features when building docs for docs.rs
|
# Enable certain features when building docs for docs.rs
|
||||||
|
|
|
@ -10,6 +10,9 @@ use crate::{SparseEntry, SparseEntryMut, SparseFormatError, SparseFormatErrorKin
|
||||||
use nalgebra::Scalar;
|
use nalgebra::Scalar;
|
||||||
use num_traits::One;
|
use num_traits::One;
|
||||||
|
|
||||||
|
#[cfg(feature = "serde-serialize")]
|
||||||
|
use serde::{de, Deserialize, Deserializer, Serialize, Serializer};
|
||||||
|
|
||||||
use std::iter::FromIterator;
|
use std::iter::FromIterator;
|
||||||
use std::slice::{Iter, IterMut};
|
use std::slice::{Iter, IterMut};
|
||||||
|
|
||||||
|
@ -596,6 +599,67 @@ impl<T> CsrMatrix<T> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(feature = "serde-serialize", derive(Serialize))]
|
||||||
|
struct CsrMatrixSerializationHelper<'a, T> {
|
||||||
|
nrows: usize,
|
||||||
|
ncols: usize,
|
||||||
|
row_offsets: &'a [usize],
|
||||||
|
col_indices: &'a [usize],
|
||||||
|
values: &'a [T],
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "serde-serialize")]
|
||||||
|
impl<T> Serialize for CsrMatrix<T>
|
||||||
|
where
|
||||||
|
T: Serialize,
|
||||||
|
{
|
||||||
|
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||||
|
where
|
||||||
|
S: Serializer,
|
||||||
|
{
|
||||||
|
CsrMatrixSerializationHelper {
|
||||||
|
nrows: self.nrows(),
|
||||||
|
ncols: self.ncols(),
|
||||||
|
row_offsets: self.row_offsets(),
|
||||||
|
col_indices: self.col_indices(),
|
||||||
|
values: self.values(),
|
||||||
|
}
|
||||||
|
.serialize(serializer)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(feature = "serde-serialize", derive(Deserialize))]
|
||||||
|
struct CsrMatrixDeserializationHelper<T> {
|
||||||
|
nrows: usize,
|
||||||
|
ncols: usize,
|
||||||
|
row_offsets: Vec<usize>,
|
||||||
|
col_indices: Vec<usize>,
|
||||||
|
values: Vec<T>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "serde-serialize")]
|
||||||
|
impl<'de, T> Deserialize<'de> for CsrMatrix<T>
|
||||||
|
where
|
||||||
|
T: for<'de2> Deserialize<'de2>,
|
||||||
|
{
|
||||||
|
fn deserialize<D>(deserializer: D) -> Result<CsrMatrix<T>, D::Error>
|
||||||
|
where
|
||||||
|
D: Deserializer<'de>,
|
||||||
|
{
|
||||||
|
let CsrMatrixDeserializationHelper {
|
||||||
|
nrows,
|
||||||
|
ncols,
|
||||||
|
row_offsets,
|
||||||
|
col_indices,
|
||||||
|
values,
|
||||||
|
} = CsrMatrixDeserializationHelper::deserialize(deserializer)?;
|
||||||
|
CsrMatrix::try_from_csr_data(nrows, ncols, row_offsets, col_indices, values)
|
||||||
|
.map(|m| m.into())
|
||||||
|
// TODO: More specific error
|
||||||
|
.map_err(|_e| de::Error::invalid_value(de::Unexpected::Other("invalid CSR matrix"), &"a valid CSR matrix"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Convert pattern format errors into more meaningful CSR-specific errors.
|
/// Convert pattern format errors into more meaningful CSR-specific errors.
|
||||||
///
|
///
|
||||||
/// This ensures that the terminology is consistent: we are talking about rows and columns,
|
/// This ensures that the terminology is consistent: we are talking about rows and columns,
|
||||||
|
|
|
@ -0,0 +1,62 @@
|
||||||
|
#![cfg(feature = "serde-serialize")]
|
||||||
|
//! Serialization tests
|
||||||
|
#[cfg(any(not(feature = "proptest-support"), not(feature = "compare")))]
|
||||||
|
compile_error!("Tests must be run with features `proptest-support` and `compare`");
|
||||||
|
|
||||||
|
#[macro_use]
|
||||||
|
pub mod common;
|
||||||
|
|
||||||
|
use nalgebra_sparse::csr::CsrMatrix;
|
||||||
|
|
||||||
|
use proptest::prelude::*;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
use crate::common::csr_strategy;
|
||||||
|
|
||||||
|
fn json_roundtrip<T: Serialize + for<'a> Deserialize<'a>>(csr: &CsrMatrix<T>) -> CsrMatrix<T> {
|
||||||
|
let serialized = serde_json::to_string(csr).unwrap();
|
||||||
|
let deserialized: CsrMatrix<T> = serde_json::from_str(&serialized).unwrap();
|
||||||
|
deserialized
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn csr_roundtrip() {
|
||||||
|
{
|
||||||
|
// A CSR matrix with zero explicitly stored entries
|
||||||
|
let offsets = vec![0, 0, 0, 0];
|
||||||
|
let indices = vec![];
|
||||||
|
let values = Vec::<i32>::new();
|
||||||
|
let matrix = CsrMatrix::try_from_csr_data(3, 2, offsets, indices, values).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(json_roundtrip(&matrix), matrix);
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
// An arbitrary CSR matrix
|
||||||
|
let offsets = vec![0, 2, 2, 5];
|
||||||
|
let indices = vec![0, 5, 1, 2, 3];
|
||||||
|
let values = vec![0, 1, 2, 3, 4];
|
||||||
|
let matrix =
|
||||||
|
CsrMatrix::try_from_csr_data(3, 6, offsets.clone(), indices.clone(), values.clone())
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert_eq!(json_roundtrip(&matrix), matrix);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn invalid_csr_deserialize() {
|
||||||
|
// Valid matrix: {"nrows":3,"ncols":6,"row_offsets":[0,2,2,5],"col_indices":[0,5,1,2,3],"values":[0,1,2,3,4]}
|
||||||
|
assert!(serde_json::from_str::<CsrMatrix<i32>>(r#"{"nrows":3,"ncols":6,"row_offsets":[0,2,2,5],"col_indices":[0,5,1,2,3],"values":[0,1,2,3]}"#).is_err());
|
||||||
|
assert!(serde_json::from_str::<CsrMatrix<i32>>(r#"{"nrows":3,"ncols":6,"row_offsets":[0,2,2,5],"col_indices":[0,5,1,8,3],"values":[0,1,2,3,4]}"#).is_err());
|
||||||
|
assert!(serde_json::from_str::<CsrMatrix<i32>>(r#"{"nrows":3,"ncols":6,"row_offsets":[0,2,2,5],"col_indices":[0,5,1,2,3,1,1],"values":[0,1,2,3,4]}"#).is_err());
|
||||||
|
// The following actually panics ('range end index 10 out of range for slice of length 5', nalgebra-sparse\src\pattern.rs:156:38)
|
||||||
|
//assert!(serde_json::from_str::<CsrMatrix<i32>>(r#"{"nrows":3,"ncols":6,"row_offsets":[0,10,2,5],"col_indices":[0,5,1,2,3],"values":[0,1,2,3,4]}"#).is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
proptest! {
|
||||||
|
#[test]
|
||||||
|
fn csr_roundtrip_proptest(csr in csr_strategy()) {
|
||||||
|
prop_assert_eq!(json_roundtrip(&csr), csr);
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue