Unify separate (de)serialization helper structs by using Cow<'a, [T]>

This commit is contained in:
Fabian Löschner 2021-12-09 18:18:32 +01:00 committed by Fabian Loeschner
parent 3be81be2e3
commit a8fa7f71c0
4 changed files with 54 additions and 93 deletions

View File

@ -1,6 +1,7 @@
//! An implementation of the COO sparse matrix format. //! An implementation of the COO sparse matrix format.
use crate::SparseFormatError; use crate::SparseFormatError;
use std::borrow::Cow;
#[cfg(feature = "serde-serialize")] #[cfg(feature = "serde-serialize")]
use serde::{de, Deserialize, Deserializer, Serialize, Serializer}; use serde::{de, Deserialize, Deserializer, Serialize, Serializer};
@ -278,19 +279,19 @@ impl<T> CooMatrix<T> {
} }
#[cfg(feature = "serde-serialize")] #[cfg(feature = "serde-serialize")]
#[derive(Serialize)] #[derive(Serialize, Deserialize)]
struct CooMatrixSerializationData<'a, T> { struct CooMatrixSerializationData<'a, T: Clone> {
nrows: usize, nrows: usize,
ncols: usize, ncols: usize,
row_indices: &'a [usize], row_indices: Cow<'a, [usize]>,
col_indices: &'a [usize], col_indices: Cow<'a, [usize]>,
values: &'a [T], values: Cow<'a, [T]>,
} }
#[cfg(feature = "serde-serialize")] #[cfg(feature = "serde-serialize")]
impl<T> Serialize for CooMatrix<T> impl<T> Serialize for CooMatrix<T>
where where
T: Serialize, T: Serialize + Clone,
{ {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where where
@ -299,42 +300,31 @@ where
CooMatrixSerializationData { CooMatrixSerializationData {
nrows: self.nrows(), nrows: self.nrows(),
ncols: self.ncols(), ncols: self.ncols(),
row_indices: self.row_indices(), row_indices: self.row_indices().into(),
col_indices: self.col_indices(), col_indices: self.col_indices().into(),
values: self.values(), values: self.values().into(),
} }
.serialize(serializer) .serialize(serializer)
} }
} }
#[cfg(feature = "serde-serialize")]
#[derive(Deserialize)]
struct CooMatrixDeserializationData<T> {
nrows: usize,
ncols: usize,
row_indices: Vec<usize>,
col_indices: Vec<usize>,
values: Vec<T>,
}
#[cfg(feature = "serde-serialize")] #[cfg(feature = "serde-serialize")]
impl<'de, T> Deserialize<'de> for CooMatrix<T> impl<'de, T> Deserialize<'de> for CooMatrix<T>
where where
T: Deserialize<'de>, T: Deserialize<'de> + Clone,
{ {
fn deserialize<D>(deserializer: D) -> Result<CooMatrix<T>, D::Error> fn deserialize<D>(deserializer: D) -> Result<CooMatrix<T>, D::Error>
where where
D: Deserializer<'de>, D: Deserializer<'de>,
{ {
let de = CooMatrixDeserializationData::deserialize(deserializer)?; let de = CooMatrixSerializationData::deserialize(deserializer)?;
CooMatrix::try_from_triplets( CooMatrix::try_from_triplets(
de.nrows, de.nrows,
de.ncols, de.ncols,
de.row_indices, de.row_indices.into(),
de.col_indices, de.col_indices.into(),
de.values, de.values.into(),
) )
.map(|m| m.into())
.map_err(|e| de::Error::custom(e)) .map_err(|e| de::Error::custom(e))
} }
} }

View File

@ -7,6 +7,7 @@ use crate::cs::{CsLane, CsLaneIter, CsLaneIterMut, CsLaneMut, CsMatrix};
use crate::csr::CsrMatrix; use crate::csr::CsrMatrix;
use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter}; use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter};
use crate::{SparseEntry, SparseEntryMut, SparseFormatError, SparseFormatErrorKind}; use crate::{SparseEntry, SparseEntryMut, SparseFormatError, SparseFormatErrorKind};
use std::borrow::Cow;
use nalgebra::Scalar; use nalgebra::Scalar;
use num_traits::One; use num_traits::One;
@ -523,19 +524,19 @@ impl<T> CscMatrix<T> {
} }
#[cfg(feature = "serde-serialize")] #[cfg(feature = "serde-serialize")]
#[derive(Serialize)] #[derive(Serialize, Deserialize)]
struct CscMatrixSerializationData<'a, T> { struct CscMatrixSerializationData<'a, T: Clone> {
nrows: usize, nrows: usize,
ncols: usize, ncols: usize,
col_offsets: &'a [usize], col_offsets: Cow<'a, [usize]>,
row_indices: &'a [usize], row_indices: Cow<'a, [usize]>,
values: &'a [T], values: Cow<'a, [T]>,
} }
#[cfg(feature = "serde-serialize")] #[cfg(feature = "serde-serialize")]
impl<T> Serialize for CscMatrix<T> impl<T> Serialize for CscMatrix<T>
where where
T: Serialize, T: Serialize + Clone,
{ {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where where
@ -544,42 +545,31 @@ where
CscMatrixSerializationData { CscMatrixSerializationData {
nrows: self.nrows(), nrows: self.nrows(),
ncols: self.ncols(), ncols: self.ncols(),
col_offsets: self.col_offsets(), col_offsets: Cow::Borrowed(self.col_offsets()),
row_indices: self.row_indices(), row_indices: Cow::Borrowed(self.row_indices()),
values: self.values(), values: Cow::Borrowed(self.values()),
} }
.serialize(serializer) .serialize(serializer)
} }
} }
#[cfg(feature = "serde-serialize")]
#[derive(Deserialize)]
struct CscMatrixDeserializationData<T> {
nrows: usize,
ncols: usize,
col_offsets: Vec<usize>,
row_indices: Vec<usize>,
values: Vec<T>,
}
#[cfg(feature = "serde-serialize")] #[cfg(feature = "serde-serialize")]
impl<'de, T> Deserialize<'de> for CscMatrix<T> impl<'de, T> Deserialize<'de> for CscMatrix<T>
where where
T: Deserialize<'de>, T: Deserialize<'de> + Clone,
{ {
fn deserialize<D>(deserializer: D) -> Result<CscMatrix<T>, D::Error> fn deserialize<D>(deserializer: D) -> Result<CscMatrix<T>, D::Error>
where where
D: Deserializer<'de>, D: Deserializer<'de>,
{ {
let de = CscMatrixDeserializationData::deserialize(deserializer)?; let de = CscMatrixSerializationData::deserialize(deserializer)?;
CscMatrix::try_from_csc_data( CscMatrix::try_from_csc_data(
de.nrows, de.nrows,
de.ncols, de.ncols,
de.col_offsets, de.col_offsets.into(),
de.row_indices, de.row_indices.into(),
de.values, de.values.into(),
) )
.map(|m| m.into())
.map_err(|e| de::Error::custom(e)) .map_err(|e| de::Error::custom(e))
} }
} }

View File

@ -6,6 +6,7 @@ use crate::cs::{CsLane, CsLaneIter, CsLaneIterMut, CsLaneMut, CsMatrix};
use crate::csc::CscMatrix; use crate::csc::CscMatrix;
use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter}; use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter};
use crate::{SparseEntry, SparseEntryMut, SparseFormatError, SparseFormatErrorKind}; use crate::{SparseEntry, SparseEntryMut, SparseFormatError, SparseFormatErrorKind};
use std::borrow::Cow;
use nalgebra::Scalar; use nalgebra::Scalar;
use num_traits::One; use num_traits::One;
@ -594,19 +595,19 @@ impl<T> CsrMatrix<T> {
} }
#[cfg(feature = "serde-serialize")] #[cfg(feature = "serde-serialize")]
#[derive(Serialize)] #[derive(Serialize, Deserialize)]
struct CsrMatrixSerializationData<'a, T> { struct CsrMatrixSerializationData<'a, T: Clone> {
nrows: usize, nrows: usize,
ncols: usize, ncols: usize,
row_offsets: &'a [usize], row_offsets: Cow<'a, [usize]>,
col_indices: &'a [usize], col_indices: Cow<'a, [usize]>,
values: &'a [T], values: Cow<'a, [T]>,
} }
#[cfg(feature = "serde-serialize")] #[cfg(feature = "serde-serialize")]
impl<T> Serialize for CsrMatrix<T> impl<T> Serialize for CsrMatrix<T>
where where
T: Serialize, T: Serialize + Clone,
{ {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where where
@ -615,42 +616,31 @@ where
CsrMatrixSerializationData { CsrMatrixSerializationData {
nrows: self.nrows(), nrows: self.nrows(),
ncols: self.ncols(), ncols: self.ncols(),
row_offsets: self.row_offsets(), row_offsets: Cow::Borrowed(self.row_offsets()),
col_indices: self.col_indices(), col_indices: Cow::Borrowed(self.col_indices()),
values: self.values(), values: Cow::Borrowed(self.values()),
} }
.serialize(serializer) .serialize(serializer)
} }
} }
#[cfg(feature = "serde-serialize")]
#[derive(Deserialize)]
struct CsrMatrixDeserializationData<T> {
nrows: usize,
ncols: usize,
row_offsets: Vec<usize>,
col_indices: Vec<usize>,
values: Vec<T>,
}
#[cfg(feature = "serde-serialize")] #[cfg(feature = "serde-serialize")]
impl<'de, T> Deserialize<'de> for CsrMatrix<T> impl<'de, T> Deserialize<'de> for CsrMatrix<T>
where where
T: Deserialize<'de>, T: Deserialize<'de> + Clone,
{ {
fn deserialize<D>(deserializer: D) -> Result<CsrMatrix<T>, D::Error> fn deserialize<D>(deserializer: D) -> Result<CsrMatrix<T>, D::Error>
where where
D: Deserializer<'de>, D: Deserializer<'de>,
{ {
let de = CsrMatrixDeserializationData::deserialize(deserializer)?; let de = CsrMatrixSerializationData::deserialize(deserializer)?;
CsrMatrix::try_from_csr_data( CsrMatrix::try_from_csr_data(
de.nrows, de.nrows,
de.ncols, de.ncols,
de.row_offsets, de.row_offsets.into(),
de.col_indices, de.col_indices.into(),
de.values, de.values.into(),
) )
.map(|m| m.into())
.map_err(|e| de::Error::custom(e)) .map_err(|e| de::Error::custom(e))
} }
} }

View File

@ -1,6 +1,7 @@
//! Sparsity patterns for CSR and CSC matrices. //! Sparsity patterns for CSR and CSC matrices.
use crate::cs::transpose_cs; use crate::cs::transpose_cs;
use crate::SparseFormatError; use crate::SparseFormatError;
use std::borrow::Cow;
use std::error::Error; use std::error::Error;
use std::fmt; use std::fmt;
@ -293,12 +294,12 @@ pub enum SparsityPatternFormatError {
} }
#[cfg(feature = "serde-serialize")] #[cfg(feature = "serde-serialize")]
#[derive(Serialize)] #[derive(Serialize, Deserialize)]
struct SparsityPatternSerializationData<'a> { struct SparsityPatternSerializationData<'a> {
major_dim: usize, major_dim: usize,
minor_dim: usize, minor_dim: usize,
major_offsets: &'a [usize], major_offsets: Cow<'a, [usize]>,
minor_indices: &'a [usize], minor_indices: Cow<'a, [usize]>,
} }
#[cfg(feature = "serde-serialize")] #[cfg(feature = "serde-serialize")]
@ -310,36 +311,26 @@ impl Serialize for SparsityPattern {
SparsityPatternSerializationData { SparsityPatternSerializationData {
major_dim: self.major_dim(), major_dim: self.major_dim(),
minor_dim: self.minor_dim(), minor_dim: self.minor_dim(),
major_offsets: self.major_offsets(), major_offsets: Cow::Borrowed(self.major_offsets()),
minor_indices: self.minor_indices(), minor_indices: Cow::Borrowed(self.minor_indices()),
} }
.serialize(serializer) .serialize(serializer)
} }
} }
#[cfg(feature = "serde-serialize")]
#[derive(Deserialize)]
struct SparsityPatternDeserializationData {
major_dim: usize,
minor_dim: usize,
major_offsets: Vec<usize>,
minor_indices: Vec<usize>,
}
#[cfg(feature = "serde-serialize")] #[cfg(feature = "serde-serialize")]
impl<'de> Deserialize<'de> for SparsityPattern { impl<'de> Deserialize<'de> for SparsityPattern {
fn deserialize<D>(deserializer: D) -> Result<SparsityPattern, D::Error> fn deserialize<D>(deserializer: D) -> Result<SparsityPattern, D::Error>
where where
D: Deserializer<'de>, D: Deserializer<'de>,
{ {
let de = SparsityPatternDeserializationData::deserialize(deserializer)?; let de = SparsityPatternSerializationData::deserialize(deserializer)?;
SparsityPattern::try_from_offsets_and_indices( SparsityPattern::try_from_offsets_and_indices(
de.major_dim, de.major_dim,
de.minor_dim, de.minor_dim,
de.major_offsets, de.major_offsets.into(),
de.minor_indices, de.minor_indices.into(),
) )
.map(|m| m.into())
.map_err(|e| de::Error::custom(e)) .map_err(|e| de::Error::custom(e))
} }
} }