Unify separate (de)serialization helper structs by using Cow<'a, [T]>
This commit is contained in:
parent
3be81be2e3
commit
a8fa7f71c0
|
@ -1,6 +1,7 @@
|
||||||
//! An implementation of the COO sparse matrix format.
|
//! An implementation of the COO sparse matrix format.
|
||||||
|
|
||||||
use crate::SparseFormatError;
|
use crate::SparseFormatError;
|
||||||
|
use std::borrow::Cow;
|
||||||
|
|
||||||
#[cfg(feature = "serde-serialize")]
|
#[cfg(feature = "serde-serialize")]
|
||||||
use serde::{de, Deserialize, Deserializer, Serialize, Serializer};
|
use serde::{de, Deserialize, Deserializer, Serialize, Serializer};
|
||||||
|
@ -278,19 +279,19 @@ impl<T> CooMatrix<T> {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature = "serde-serialize")]
|
#[cfg(feature = "serde-serialize")]
|
||||||
#[derive(Serialize)]
|
#[derive(Serialize, Deserialize)]
|
||||||
struct CooMatrixSerializationData<'a, T> {
|
struct CooMatrixSerializationData<'a, T: Clone> {
|
||||||
nrows: usize,
|
nrows: usize,
|
||||||
ncols: usize,
|
ncols: usize,
|
||||||
row_indices: &'a [usize],
|
row_indices: Cow<'a, [usize]>,
|
||||||
col_indices: &'a [usize],
|
col_indices: Cow<'a, [usize]>,
|
||||||
values: &'a [T],
|
values: Cow<'a, [T]>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature = "serde-serialize")]
|
#[cfg(feature = "serde-serialize")]
|
||||||
impl<T> Serialize for CooMatrix<T>
|
impl<T> Serialize for CooMatrix<T>
|
||||||
where
|
where
|
||||||
T: Serialize,
|
T: Serialize + Clone,
|
||||||
{
|
{
|
||||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||||
where
|
where
|
||||||
|
@ -299,42 +300,31 @@ where
|
||||||
CooMatrixSerializationData {
|
CooMatrixSerializationData {
|
||||||
nrows: self.nrows(),
|
nrows: self.nrows(),
|
||||||
ncols: self.ncols(),
|
ncols: self.ncols(),
|
||||||
row_indices: self.row_indices(),
|
row_indices: self.row_indices().into(),
|
||||||
col_indices: self.col_indices(),
|
col_indices: self.col_indices().into(),
|
||||||
values: self.values(),
|
values: self.values().into(),
|
||||||
}
|
}
|
||||||
.serialize(serializer)
|
.serialize(serializer)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature = "serde-serialize")]
|
|
||||||
#[derive(Deserialize)]
|
|
||||||
struct CooMatrixDeserializationData<T> {
|
|
||||||
nrows: usize,
|
|
||||||
ncols: usize,
|
|
||||||
row_indices: Vec<usize>,
|
|
||||||
col_indices: Vec<usize>,
|
|
||||||
values: Vec<T>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(feature = "serde-serialize")]
|
#[cfg(feature = "serde-serialize")]
|
||||||
impl<'de, T> Deserialize<'de> for CooMatrix<T>
|
impl<'de, T> Deserialize<'de> for CooMatrix<T>
|
||||||
where
|
where
|
||||||
T: Deserialize<'de>,
|
T: Deserialize<'de> + Clone,
|
||||||
{
|
{
|
||||||
fn deserialize<D>(deserializer: D) -> Result<CooMatrix<T>, D::Error>
|
fn deserialize<D>(deserializer: D) -> Result<CooMatrix<T>, D::Error>
|
||||||
where
|
where
|
||||||
D: Deserializer<'de>,
|
D: Deserializer<'de>,
|
||||||
{
|
{
|
||||||
let de = CooMatrixDeserializationData::deserialize(deserializer)?;
|
let de = CooMatrixSerializationData::deserialize(deserializer)?;
|
||||||
CooMatrix::try_from_triplets(
|
CooMatrix::try_from_triplets(
|
||||||
de.nrows,
|
de.nrows,
|
||||||
de.ncols,
|
de.ncols,
|
||||||
de.row_indices,
|
de.row_indices.into(),
|
||||||
de.col_indices,
|
de.col_indices.into(),
|
||||||
de.values,
|
de.values.into(),
|
||||||
)
|
)
|
||||||
.map(|m| m.into())
|
|
||||||
.map_err(|e| de::Error::custom(e))
|
.map_err(|e| de::Error::custom(e))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,6 +7,7 @@ use crate::cs::{CsLane, CsLaneIter, CsLaneIterMut, CsLaneMut, CsMatrix};
|
||||||
use crate::csr::CsrMatrix;
|
use crate::csr::CsrMatrix;
|
||||||
use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter};
|
use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter};
|
||||||
use crate::{SparseEntry, SparseEntryMut, SparseFormatError, SparseFormatErrorKind};
|
use crate::{SparseEntry, SparseEntryMut, SparseFormatError, SparseFormatErrorKind};
|
||||||
|
use std::borrow::Cow;
|
||||||
|
|
||||||
use nalgebra::Scalar;
|
use nalgebra::Scalar;
|
||||||
use num_traits::One;
|
use num_traits::One;
|
||||||
|
@ -523,19 +524,19 @@ impl<T> CscMatrix<T> {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature = "serde-serialize")]
|
#[cfg(feature = "serde-serialize")]
|
||||||
#[derive(Serialize)]
|
#[derive(Serialize, Deserialize)]
|
||||||
struct CscMatrixSerializationData<'a, T> {
|
struct CscMatrixSerializationData<'a, T: Clone> {
|
||||||
nrows: usize,
|
nrows: usize,
|
||||||
ncols: usize,
|
ncols: usize,
|
||||||
col_offsets: &'a [usize],
|
col_offsets: Cow<'a, [usize]>,
|
||||||
row_indices: &'a [usize],
|
row_indices: Cow<'a, [usize]>,
|
||||||
values: &'a [T],
|
values: Cow<'a, [T]>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature = "serde-serialize")]
|
#[cfg(feature = "serde-serialize")]
|
||||||
impl<T> Serialize for CscMatrix<T>
|
impl<T> Serialize for CscMatrix<T>
|
||||||
where
|
where
|
||||||
T: Serialize,
|
T: Serialize + Clone,
|
||||||
{
|
{
|
||||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||||
where
|
where
|
||||||
|
@ -544,42 +545,31 @@ where
|
||||||
CscMatrixSerializationData {
|
CscMatrixSerializationData {
|
||||||
nrows: self.nrows(),
|
nrows: self.nrows(),
|
||||||
ncols: self.ncols(),
|
ncols: self.ncols(),
|
||||||
col_offsets: self.col_offsets(),
|
col_offsets: Cow::Borrowed(self.col_offsets()),
|
||||||
row_indices: self.row_indices(),
|
row_indices: Cow::Borrowed(self.row_indices()),
|
||||||
values: self.values(),
|
values: Cow::Borrowed(self.values()),
|
||||||
}
|
}
|
||||||
.serialize(serializer)
|
.serialize(serializer)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature = "serde-serialize")]
|
|
||||||
#[derive(Deserialize)]
|
|
||||||
struct CscMatrixDeserializationData<T> {
|
|
||||||
nrows: usize,
|
|
||||||
ncols: usize,
|
|
||||||
col_offsets: Vec<usize>,
|
|
||||||
row_indices: Vec<usize>,
|
|
||||||
values: Vec<T>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(feature = "serde-serialize")]
|
#[cfg(feature = "serde-serialize")]
|
||||||
impl<'de, T> Deserialize<'de> for CscMatrix<T>
|
impl<'de, T> Deserialize<'de> for CscMatrix<T>
|
||||||
where
|
where
|
||||||
T: Deserialize<'de>,
|
T: Deserialize<'de> + Clone,
|
||||||
{
|
{
|
||||||
fn deserialize<D>(deserializer: D) -> Result<CscMatrix<T>, D::Error>
|
fn deserialize<D>(deserializer: D) -> Result<CscMatrix<T>, D::Error>
|
||||||
where
|
where
|
||||||
D: Deserializer<'de>,
|
D: Deserializer<'de>,
|
||||||
{
|
{
|
||||||
let de = CscMatrixDeserializationData::deserialize(deserializer)?;
|
let de = CscMatrixSerializationData::deserialize(deserializer)?;
|
||||||
CscMatrix::try_from_csc_data(
|
CscMatrix::try_from_csc_data(
|
||||||
de.nrows,
|
de.nrows,
|
||||||
de.ncols,
|
de.ncols,
|
||||||
de.col_offsets,
|
de.col_offsets.into(),
|
||||||
de.row_indices,
|
de.row_indices.into(),
|
||||||
de.values,
|
de.values.into(),
|
||||||
)
|
)
|
||||||
.map(|m| m.into())
|
|
||||||
.map_err(|e| de::Error::custom(e))
|
.map_err(|e| de::Error::custom(e))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,6 +6,7 @@ use crate::cs::{CsLane, CsLaneIter, CsLaneIterMut, CsLaneMut, CsMatrix};
|
||||||
use crate::csc::CscMatrix;
|
use crate::csc::CscMatrix;
|
||||||
use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter};
|
use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter};
|
||||||
use crate::{SparseEntry, SparseEntryMut, SparseFormatError, SparseFormatErrorKind};
|
use crate::{SparseEntry, SparseEntryMut, SparseFormatError, SparseFormatErrorKind};
|
||||||
|
use std::borrow::Cow;
|
||||||
|
|
||||||
use nalgebra::Scalar;
|
use nalgebra::Scalar;
|
||||||
use num_traits::One;
|
use num_traits::One;
|
||||||
|
@ -594,19 +595,19 @@ impl<T> CsrMatrix<T> {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature = "serde-serialize")]
|
#[cfg(feature = "serde-serialize")]
|
||||||
#[derive(Serialize)]
|
#[derive(Serialize, Deserialize)]
|
||||||
struct CsrMatrixSerializationData<'a, T> {
|
struct CsrMatrixSerializationData<'a, T: Clone> {
|
||||||
nrows: usize,
|
nrows: usize,
|
||||||
ncols: usize,
|
ncols: usize,
|
||||||
row_offsets: &'a [usize],
|
row_offsets: Cow<'a, [usize]>,
|
||||||
col_indices: &'a [usize],
|
col_indices: Cow<'a, [usize]>,
|
||||||
values: &'a [T],
|
values: Cow<'a, [T]>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature = "serde-serialize")]
|
#[cfg(feature = "serde-serialize")]
|
||||||
impl<T> Serialize for CsrMatrix<T>
|
impl<T> Serialize for CsrMatrix<T>
|
||||||
where
|
where
|
||||||
T: Serialize,
|
T: Serialize + Clone,
|
||||||
{
|
{
|
||||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||||
where
|
where
|
||||||
|
@ -615,42 +616,31 @@ where
|
||||||
CsrMatrixSerializationData {
|
CsrMatrixSerializationData {
|
||||||
nrows: self.nrows(),
|
nrows: self.nrows(),
|
||||||
ncols: self.ncols(),
|
ncols: self.ncols(),
|
||||||
row_offsets: self.row_offsets(),
|
row_offsets: Cow::Borrowed(self.row_offsets()),
|
||||||
col_indices: self.col_indices(),
|
col_indices: Cow::Borrowed(self.col_indices()),
|
||||||
values: self.values(),
|
values: Cow::Borrowed(self.values()),
|
||||||
}
|
}
|
||||||
.serialize(serializer)
|
.serialize(serializer)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature = "serde-serialize")]
|
|
||||||
#[derive(Deserialize)]
|
|
||||||
struct CsrMatrixDeserializationData<T> {
|
|
||||||
nrows: usize,
|
|
||||||
ncols: usize,
|
|
||||||
row_offsets: Vec<usize>,
|
|
||||||
col_indices: Vec<usize>,
|
|
||||||
values: Vec<T>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(feature = "serde-serialize")]
|
#[cfg(feature = "serde-serialize")]
|
||||||
impl<'de, T> Deserialize<'de> for CsrMatrix<T>
|
impl<'de, T> Deserialize<'de> for CsrMatrix<T>
|
||||||
where
|
where
|
||||||
T: Deserialize<'de>,
|
T: Deserialize<'de> + Clone,
|
||||||
{
|
{
|
||||||
fn deserialize<D>(deserializer: D) -> Result<CsrMatrix<T>, D::Error>
|
fn deserialize<D>(deserializer: D) -> Result<CsrMatrix<T>, D::Error>
|
||||||
where
|
where
|
||||||
D: Deserializer<'de>,
|
D: Deserializer<'de>,
|
||||||
{
|
{
|
||||||
let de = CsrMatrixDeserializationData::deserialize(deserializer)?;
|
let de = CsrMatrixSerializationData::deserialize(deserializer)?;
|
||||||
CsrMatrix::try_from_csr_data(
|
CsrMatrix::try_from_csr_data(
|
||||||
de.nrows,
|
de.nrows,
|
||||||
de.ncols,
|
de.ncols,
|
||||||
de.row_offsets,
|
de.row_offsets.into(),
|
||||||
de.col_indices,
|
de.col_indices.into(),
|
||||||
de.values,
|
de.values.into(),
|
||||||
)
|
)
|
||||||
.map(|m| m.into())
|
|
||||||
.map_err(|e| de::Error::custom(e))
|
.map_err(|e| de::Error::custom(e))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
//! Sparsity patterns for CSR and CSC matrices.
|
//! Sparsity patterns for CSR and CSC matrices.
|
||||||
use crate::cs::transpose_cs;
|
use crate::cs::transpose_cs;
|
||||||
use crate::SparseFormatError;
|
use crate::SparseFormatError;
|
||||||
|
use std::borrow::Cow;
|
||||||
use std::error::Error;
|
use std::error::Error;
|
||||||
use std::fmt;
|
use std::fmt;
|
||||||
|
|
||||||
|
@ -293,12 +294,12 @@ pub enum SparsityPatternFormatError {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature = "serde-serialize")]
|
#[cfg(feature = "serde-serialize")]
|
||||||
#[derive(Serialize)]
|
#[derive(Serialize, Deserialize)]
|
||||||
struct SparsityPatternSerializationData<'a> {
|
struct SparsityPatternSerializationData<'a> {
|
||||||
major_dim: usize,
|
major_dim: usize,
|
||||||
minor_dim: usize,
|
minor_dim: usize,
|
||||||
major_offsets: &'a [usize],
|
major_offsets: Cow<'a, [usize]>,
|
||||||
minor_indices: &'a [usize],
|
minor_indices: Cow<'a, [usize]>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature = "serde-serialize")]
|
#[cfg(feature = "serde-serialize")]
|
||||||
|
@ -310,36 +311,26 @@ impl Serialize for SparsityPattern {
|
||||||
SparsityPatternSerializationData {
|
SparsityPatternSerializationData {
|
||||||
major_dim: self.major_dim(),
|
major_dim: self.major_dim(),
|
||||||
minor_dim: self.minor_dim(),
|
minor_dim: self.minor_dim(),
|
||||||
major_offsets: self.major_offsets(),
|
major_offsets: Cow::Borrowed(self.major_offsets()),
|
||||||
minor_indices: self.minor_indices(),
|
minor_indices: Cow::Borrowed(self.minor_indices()),
|
||||||
}
|
}
|
||||||
.serialize(serializer)
|
.serialize(serializer)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature = "serde-serialize")]
|
|
||||||
#[derive(Deserialize)]
|
|
||||||
struct SparsityPatternDeserializationData {
|
|
||||||
major_dim: usize,
|
|
||||||
minor_dim: usize,
|
|
||||||
major_offsets: Vec<usize>,
|
|
||||||
minor_indices: Vec<usize>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(feature = "serde-serialize")]
|
#[cfg(feature = "serde-serialize")]
|
||||||
impl<'de> Deserialize<'de> for SparsityPattern {
|
impl<'de> Deserialize<'de> for SparsityPattern {
|
||||||
fn deserialize<D>(deserializer: D) -> Result<SparsityPattern, D::Error>
|
fn deserialize<D>(deserializer: D) -> Result<SparsityPattern, D::Error>
|
||||||
where
|
where
|
||||||
D: Deserializer<'de>,
|
D: Deserializer<'de>,
|
||||||
{
|
{
|
||||||
let de = SparsityPatternDeserializationData::deserialize(deserializer)?;
|
let de = SparsityPatternSerializationData::deserialize(deserializer)?;
|
||||||
SparsityPattern::try_from_offsets_and_indices(
|
SparsityPattern::try_from_offsets_and_indices(
|
||||||
de.major_dim,
|
de.major_dim,
|
||||||
de.minor_dim,
|
de.minor_dim,
|
||||||
de.major_offsets,
|
de.major_offsets.into(),
|
||||||
de.minor_indices,
|
de.minor_indices.into(),
|
||||||
)
|
)
|
||||||
.map(|m| m.into())
|
|
||||||
.map_err(|e| de::Error::custom(e))
|
.map_err(|e| de::Error::custom(e))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue