forked from M-Labs/nalgebra
Implement Serialize, Deserialize for Csc, Coo; move helper out of impls
This commit is contained in:
parent
18b694dad2
commit
40d8a904a3
@ -2,6 +2,9 @@
|
||||
|
||||
use crate::SparseFormatError;
|
||||
|
||||
#[cfg(feature = "serde-serialize")]
|
||||
use serde::{de, Deserialize, Deserializer, Serialize, Serializer};
|
||||
|
||||
/// A COO representation of a sparse matrix.
|
||||
///
|
||||
/// A COO matrix stores entries in coordinate-form, that is triplets `(i, j, v)`, where `i` and `j`
|
||||
@ -273,3 +276,60 @@ impl<T> CooMatrix<T> {
|
||||
(self.row_indices, self.col_indices, self.values)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "serde-serialize")]
|
||||
#[derive(Serialize)]
|
||||
struct CooMatrixSerializationData<'a, T> {
|
||||
nrows: usize,
|
||||
ncols: usize,
|
||||
row_indices: &'a [usize],
|
||||
col_indices: &'a [usize],
|
||||
values: &'a [T],
|
||||
}
|
||||
|
||||
#[cfg(feature = "serde-serialize")]
|
||||
impl<T> Serialize for CooMatrix<T>
|
||||
where
|
||||
T: Serialize,
|
||||
{
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: Serializer,
|
||||
{
|
||||
CooMatrixSerializationData {
|
||||
nrows: self.nrows(),
|
||||
ncols: self.ncols(),
|
||||
row_indices: self.row_indices(),
|
||||
col_indices: self.col_indices(),
|
||||
values: self.values(),
|
||||
}
|
||||
.serialize(serializer)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "serde-serialize")]
|
||||
#[derive(Deserialize)]
|
||||
struct CooMatrixDeserializationData<T> {
|
||||
nrows: usize,
|
||||
ncols: usize,
|
||||
row_indices: Vec<usize>,
|
||||
col_indices: Vec<usize>,
|
||||
values: Vec<T>,
|
||||
}
|
||||
|
||||
#[cfg(feature = "serde-serialize")]
|
||||
impl<'de, T> Deserialize<'de> for CooMatrix<T>
|
||||
where
|
||||
T: for<'de2> Deserialize<'de2>,
|
||||
{
|
||||
fn deserialize<D>(deserializer: D) -> Result<CooMatrix<T>, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
let de = CooMatrixDeserializationData::deserialize(deserializer)?;
|
||||
CooMatrix::try_from_triplets(de.nrows, de.ncols, de.row_indices, de.col_indices, de.values)
|
||||
.map(|m| m.into())
|
||||
// TODO: More specific error
|
||||
.map_err(|_e| de::Error::invalid_value(de::Unexpected::Other("invalid COO matrix"), &"a valid COO matrix"))
|
||||
}
|
||||
}
|
||||
|
@ -10,6 +10,8 @@ use crate::{SparseEntry, SparseEntryMut, SparseFormatError, SparseFormatErrorKin
|
||||
|
||||
use nalgebra::Scalar;
|
||||
use num_traits::One;
|
||||
#[cfg(feature = "serde-serialize")]
|
||||
use serde::{de, Deserialize, Deserializer, Serialize, Serializer};
|
||||
use std::slice::{Iter, IterMut};
|
||||
|
||||
/// A CSC representation of a sparse matrix.
|
||||
@ -524,6 +526,63 @@ impl<T> CscMatrix<T> {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "serde-serialize")]
|
||||
#[derive(Serialize)]
|
||||
struct CscMatrixSerializationData<'a, T> {
|
||||
nrows: usize,
|
||||
ncols: usize,
|
||||
col_offsets: &'a [usize],
|
||||
row_indices: &'a [usize],
|
||||
values: &'a [T],
|
||||
}
|
||||
|
||||
#[cfg(feature = "serde-serialize")]
|
||||
impl<T> Serialize for CscMatrix<T>
|
||||
where
|
||||
T: Serialize,
|
||||
{
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: Serializer,
|
||||
{
|
||||
CscMatrixSerializationData {
|
||||
nrows: self.nrows(),
|
||||
ncols: self.ncols(),
|
||||
col_offsets: self.col_offsets(),
|
||||
row_indices: self.row_indices(),
|
||||
values: self.values(),
|
||||
}
|
||||
.serialize(serializer)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "serde-serialize")]
|
||||
#[derive(Deserialize)]
|
||||
struct CscMatrixDeserializationData<T> {
|
||||
nrows: usize,
|
||||
ncols: usize,
|
||||
col_offsets: Vec<usize>,
|
||||
row_indices: Vec<usize>,
|
||||
values: Vec<T>,
|
||||
}
|
||||
|
||||
#[cfg(feature = "serde-serialize")]
|
||||
impl<'de, T> Deserialize<'de> for CscMatrix<T>
|
||||
where
|
||||
T: for<'de2> Deserialize<'de2>,
|
||||
{
|
||||
fn deserialize<D>(deserializer: D) -> Result<CscMatrix<T>, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
let de = CscMatrixDeserializationData::deserialize(deserializer)?;
|
||||
CscMatrix::try_from_csc_data(de.nrows, de.ncols, de.col_offsets, de.row_indices, de.values)
|
||||
.map(|m| m.into())
|
||||
// TODO: More specific error
|
||||
.map_err(|_e| de::Error::invalid_value(de::Unexpected::Other("invalid CSC matrix"), &"a valid CSC matrix"))
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert pattern format errors into more meaningful CSC-specific errors.
|
||||
///
|
||||
/// This ensures that the terminology is consistent: we are talking about rows and columns,
|
||||
|
@ -9,10 +9,8 @@ use crate::{SparseEntry, SparseEntryMut, SparseFormatError, SparseFormatErrorKin
|
||||
|
||||
use nalgebra::Scalar;
|
||||
use num_traits::One;
|
||||
|
||||
#[cfg(feature = "serde-serialize")]
|
||||
use serde::{de, Deserialize, Deserializer, Serialize, Serializer};
|
||||
|
||||
use std::iter::FromIterator;
|
||||
use std::slice::{Iter, IterMut};
|
||||
|
||||
@ -600,14 +598,6 @@ impl<T> CsrMatrix<T> {
|
||||
}
|
||||
|
||||
#[cfg(feature = "serde-serialize")]
|
||||
impl<T> Serialize for CsrMatrix<T>
|
||||
where
|
||||
T: Serialize,
|
||||
{
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: Serializer,
|
||||
{
|
||||
#[derive(Serialize)]
|
||||
struct CsrMatrixSerializationData<'a, T> {
|
||||
nrows: usize,
|
||||
@ -617,6 +607,15 @@ where
|
||||
values: &'a [T],
|
||||
}
|
||||
|
||||
#[cfg(feature = "serde-serialize")]
|
||||
impl<T> Serialize for CsrMatrix<T>
|
||||
where
|
||||
T: Serialize,
|
||||
{
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: Serializer,
|
||||
{
|
||||
CsrMatrixSerializationData {
|
||||
nrows: self.nrows(),
|
||||
ncols: self.ncols(),
|
||||
@ -629,14 +628,6 @@ where
|
||||
}
|
||||
|
||||
#[cfg(feature = "serde-serialize")]
|
||||
impl<'de, T> Deserialize<'de> for CsrMatrix<T>
|
||||
where
|
||||
T: for<'de2> Deserialize<'de2>,
|
||||
{
|
||||
fn deserialize<D>(deserializer: D) -> Result<CsrMatrix<T>, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
#[derive(Deserialize)]
|
||||
struct CsrMatrixDeserializationData<T> {
|
||||
nrows: usize,
|
||||
@ -646,6 +637,15 @@ where
|
||||
values: Vec<T>,
|
||||
}
|
||||
|
||||
#[cfg(feature = "serde-serialize")]
|
||||
impl<'de, T> Deserialize<'de> for CsrMatrix<T>
|
||||
where
|
||||
T: for<'de2> Deserialize<'de2>,
|
||||
{
|
||||
fn deserialize<D>(deserializer: D) -> Result<CsrMatrix<T>, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
let de = CsrMatrixDeserializationData::deserialize(deserializer)?;
|
||||
CsrMatrix::try_from_csr_data(de.nrows, de.ncols, de.row_offsets, de.col_indices, de.values)
|
||||
.map(|m| m.into())
|
||||
|
@ -6,19 +6,112 @@ compile_error!("Tests must be run with features `proptest-support` and `compare`
|
||||
#[macro_use]
|
||||
pub mod common;
|
||||
|
||||
use nalgebra_sparse::coo::CooMatrix;
|
||||
use nalgebra_sparse::csc::CscMatrix;
|
||||
use nalgebra_sparse::csr::CsrMatrix;
|
||||
|
||||
use proptest::prelude::*;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::common::csr_strategy;
|
||||
use crate::common::{csc_strategy, csr_strategy};
|
||||
|
||||
fn json_roundtrip<T: Serialize + for<'a> Deserialize<'a>>(csr: &CsrMatrix<T>) -> CsrMatrix<T> {
|
||||
fn json_roundtrip<T: Serialize + for<'a> Deserialize<'a>>(csr: &T) -> T {
|
||||
let serialized = serde_json::to_string(csr).unwrap();
|
||||
let deserialized: CsrMatrix<T> = serde_json::from_str(&serialized).unwrap();
|
||||
let deserialized: T = serde_json::from_str(&serialized).unwrap();
|
||||
deserialized
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn coo_roundtrip() {
|
||||
{
|
||||
// A COO matrix without entries
|
||||
let matrix =
|
||||
CooMatrix::<i32>::try_from_triplets(3, 2, Vec::new(), Vec::new(), Vec::new()).unwrap();
|
||||
|
||||
assert_eq!(json_roundtrip(&matrix), matrix);
|
||||
}
|
||||
|
||||
{
|
||||
// Arbitrary COO matrix, no duplicates
|
||||
let i = vec![0, 1, 0, 0, 2];
|
||||
let j = vec![0, 2, 1, 3, 3];
|
||||
let v = vec![2, 3, 7, 3, 1];
|
||||
let matrix =
|
||||
CooMatrix::<i32>::try_from_triplets(3, 5, i.clone(), j.clone(), v.clone()).unwrap();
|
||||
|
||||
assert_eq!(json_roundtrip(&matrix), matrix);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn coo_deserialize_invalid() {
|
||||
// Valid matrix: {"nrows":3,"ncols":5,"row_indices":[0,1,0,0,2],"col_indices":[0,2,1,3,3],"values":[2,3,7,3,1]}
|
||||
assert!(serde_json::from_str::<CooMatrix<i32>>(r#"{"nrows":0,"ncols":0,"row_indices":[0,1,0,0,2],"col_indices":[0,2,1,3,3],"values":[2,3,7,3,4]}"#).is_err());
|
||||
assert!(serde_json::from_str::<CooMatrix<i32>>(r#"{"nrows":-3,"ncols":5,"row_indices":[0,1,0,0,2],"col_indices":[0,2,1,3,3],"values":[2,3,7,3,4]}"#).is_err());
|
||||
assert!(serde_json::from_str::<CooMatrix<i32>>(r#"{"nrows":3,"ncols":5,"row_indices":[0,1,0,0,2],"col_indices":[0,2,1,3,3],"values":[2,3,7,3]}"#).is_err());
|
||||
assert!(serde_json::from_str::<CooMatrix<i32>>(r#"{"nrows":3,"ncols":5,"row_indices":[0,1,0,0,2],"col_indices":[0,2,1,3,3],"values":[2,3,7,3,4,5]}"#).is_err());
|
||||
assert!(serde_json::from_str::<CooMatrix<i32>>(r#"{"nrows":3,"ncols":5,"row_indices":[0,1,0,0,2],"col_indices":[0,2,1,8,3],"values":[2,3,7,3,4]}"#).is_err());
|
||||
assert!(serde_json::from_str::<CooMatrix<i32>>(r#"{"nrows":3,"ncols":5,"row_indices":[0,1,0,0],"col_indices":[0,2,1,8,3],"values":[2,3,7,3,4]}"#).is_err());
|
||||
assert!(serde_json::from_str::<CooMatrix<i32>>(r#"{"nrows":3,"ncols":5,"row_indices":[0,10,0,0,2],"col_indices":[0,2,1,3,3],"values":[2,3,7,3,4]}"#).is_err());
|
||||
assert!(serde_json::from_str::<CooMatrix<i32>>(r#"{"nrows":3,"ncols":5,"row_indices":[0,1,0,0,2],"col_indices":[0,2,1,30,3],"values":[2,3,7,3,4]}"#).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn coo_deserialize_duplicates() {
|
||||
assert_eq!(
|
||||
serde_json::from_str::<CooMatrix<i32>>(
|
||||
r#"{"nrows":3,"ncols":5,"row_indices":[0,1,0,0,2,0,1],"col_indices":[0,2,1,3,3,0,2],"values":[2,3,7,3,1,5,6]}"#
|
||||
).unwrap(),
|
||||
CooMatrix::<i32>::try_from_triplets(
|
||||
3,
|
||||
5,
|
||||
vec![0, 1, 0, 0, 2, 0, 1],
|
||||
vec![0, 2, 1, 3, 3, 0, 2],
|
||||
vec![2, 3, 7, 3, 1, 5, 6]
|
||||
)
|
||||
.unwrap()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn csc_roundtrip() {
|
||||
{
|
||||
// A CSC matrix with zero explicitly stored entries
|
||||
let offsets = vec![0, 0, 0, 0];
|
||||
let indices = vec![];
|
||||
let values = Vec::<i32>::new();
|
||||
let matrix = CscMatrix::try_from_csc_data(2, 3, offsets, indices, values).unwrap();
|
||||
|
||||
assert_eq!(json_roundtrip(&matrix), matrix);
|
||||
}
|
||||
|
||||
{
|
||||
// An arbitrary CSC matrix
|
||||
let offsets = vec![0, 2, 2, 5];
|
||||
let indices = vec![0, 5, 1, 2, 3];
|
||||
let values = vec![0, 1, 2, 3, 4];
|
||||
let matrix =
|
||||
CscMatrix::try_from_csc_data(6, 3, offsets.clone(), indices.clone(), values.clone())
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(json_roundtrip(&matrix), matrix);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn csc_deserialize_invalid() {
|
||||
// Valid matrix: {"nrows":6,"ncols":3,"col_offsets":[0,2,2,5],"row_indices":[0,5,1,2,3],"values":[0,1,2,3,4]}
|
||||
assert!(serde_json::from_str::<CscMatrix<i32>>(r#"{"nrows":0,"ncols":0,"col_offsets":[0,2,2,5],"row_indices":[0,5,1,2,3],"values":[0,1,2,3,4]}"#).is_err());
|
||||
assert!(serde_json::from_str::<CscMatrix<i32>>(r#"{"nrows":-6,"ncols":3,"col_offsets":[0,2,2,5],"row_indices":[0,5,1,2,3],"values":[0,1,2,3,4]}"#).is_err());
|
||||
assert!(serde_json::from_str::<CscMatrix<i32>>(r#"{"nrows":6,"ncols":3,"col_offsets":[0,2,2,5],"row_indices":[0,5,1,2,3],"values":[0,1,2,3]}"#).is_err());
|
||||
assert!(serde_json::from_str::<CscMatrix<i32>>(r#"{"nrows":6,"ncols":3,"col_offsets":[0,2,2,5],"row_indices":[0,5,1,2,3],"values":[0,1,2,3,4,5]}"#).is_err());
|
||||
assert!(serde_json::from_str::<CscMatrix<i32>>(r#"{"nrows":6,"ncols":3,"col_offsets":[0,2,2,5],"row_indices":[0,5,1,8,3],"values":[0,1,2,3,4]}"#).is_err());
|
||||
assert!(serde_json::from_str::<CscMatrix<i32>>(r#"{"nrows":6,"ncols":3,"col_offsets":[0,2,2,5],"row_indices":[0,5,1,2,3,1,1],"values":[0,1,2,3,4]}"#).is_err());
|
||||
// The following actually panics ('range end index 10 out of range for slice of length 5', nalgebra-sparse\src\pattern.rs:156:38)
|
||||
//assert!(serde_json::from_str::<CscMatrix<i32>>(r#"{"nrows":6,"ncols":3,"col_offsets":[0,10,2,5],"row_indices":[0,5,1,2,3],"values":[0,1,2,3,4]}"#).is_err());
|
||||
assert!(serde_json::from_str::<CscMatrix<i32>>(r#"{"nrows":3,"ncols":6,"row_offsets":[0,2,2,5],"col_indices":[0,5,1,2,3],"values":[0,1,2,3,4]}"#).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn csr_roundtrip() {
|
||||
{
|
||||
@ -45,16 +138,25 @@ fn csr_roundtrip() {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn invalid_csr_deserialize() {
|
||||
fn csr_deserialize_invalid() {
|
||||
// Valid matrix: {"nrows":3,"ncols":6,"row_offsets":[0,2,2,5],"col_indices":[0,5,1,2,3],"values":[0,1,2,3,4]}
|
||||
assert!(serde_json::from_str::<CsrMatrix<i32>>(r#"{"nrows":0,"ncols":0,"row_offsets":[0,2,2,5],"col_indices":[0,5,1,2,3],"values":[0,1,2,3,4]}"#).is_err());
|
||||
assert!(serde_json::from_str::<CsrMatrix<i32>>(r#"{"nrows":-3,"ncols":6,"row_offsets":[0,2,2,5],"col_indices":[0,5,1,2,3],"values":[0,1,2,3,4]}"#).is_err());
|
||||
assert!(serde_json::from_str::<CsrMatrix<i32>>(r#"{"nrows":3,"ncols":6,"row_offsets":[0,2,2,5],"col_indices":[0,5,1,2,3],"values":[0,1,2,3]}"#).is_err());
|
||||
assert!(serde_json::from_str::<CsrMatrix<i32>>(r#"{"nrows":3,"ncols":6,"row_offsets":[0,2,2,5],"col_indices":[0,5,1,2,3],"values":[0,1,2,3,4,5]}"#).is_err());
|
||||
assert!(serde_json::from_str::<CsrMatrix<i32>>(r#"{"nrows":3,"ncols":6,"row_offsets":[0,2,2,5],"col_indices":[0,5,1,8,3],"values":[0,1,2,3,4]}"#).is_err());
|
||||
assert!(serde_json::from_str::<CsrMatrix<i32>>(r#"{"nrows":3,"ncols":6,"row_offsets":[0,2,2,5],"col_indices":[0,5,1,2,3,1,1],"values":[0,1,2,3,4]}"#).is_err());
|
||||
// The following actually panics ('range end index 10 out of range for slice of length 5', nalgebra-sparse\src\pattern.rs:156:38)
|
||||
//assert!(serde_json::from_str::<CsrMatrix<i32>>(r#"{"nrows":3,"ncols":6,"row_offsets":[0,10,2,5],"col_indices":[0,5,1,2,3],"values":[0,1,2,3,4]}"#).is_err());
|
||||
assert!(serde_json::from_str::<CsrMatrix<i32>>(r#"{"nrows":6,"ncols":3,"col_offsets":[0,2,2,5],"row_indices":[0,5,1,2,3],"values":[0,1,2,3,4]}"#).is_err());
|
||||
}
|
||||
|
||||
proptest! {
|
||||
#[test]
|
||||
fn csc_roundtrip_proptest(csc in csc_strategy()) {
|
||||
prop_assert_eq!(json_roundtrip(&csc), csc);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn csr_roundtrip_proptest(csr in csr_strategy()) {
|
||||
prop_assert_eq!(json_roundtrip(&csr), csr);
|
||||
|
Loading…
Reference in New Issue
Block a user