Replace usage of Cow with generic type

This commit is contained in:
Fabian Löschner 2021-12-15 11:43:35 +01:00 committed by Fabian Loeschner
parent 49eb1bd778
commit 837ded932e
4 changed files with 45 additions and 53 deletions

View File

@ -4,8 +4,6 @@ use crate::SparseFormatError;
#[cfg(feature = "serde-serialize")] #[cfg(feature = "serde-serialize")]
use serde::{de, Deserialize, Deserializer, Serialize, Serializer}; use serde::{de, Deserialize, Deserializer, Serialize, Serializer};
#[cfg(feature = "serde-serialize")]
use std::borrow::Cow;
/// A COO representation of a sparse matrix. /// A COO representation of a sparse matrix.
/// ///
@ -281,12 +279,12 @@ impl<T> CooMatrix<T> {
#[cfg(feature = "serde-serialize")] #[cfg(feature = "serde-serialize")]
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
struct CooMatrixSerializationData<'a, T: Clone> { struct CooMatrixSerializationData<Indices, Values> {
nrows: usize, nrows: usize,
ncols: usize, ncols: usize,
row_indices: Cow<'a, [usize]>, row_indices: Indices,
col_indices: Cow<'a, [usize]>, col_indices: Indices,
values: Cow<'a, [T]>, values: Values,
} }
#[cfg(feature = "serde-serialize")] #[cfg(feature = "serde-serialize")]
@ -298,12 +296,12 @@ where
where where
S: Serializer, S: Serializer,
{ {
CooMatrixSerializationData { CooMatrixSerializationData::<&[usize], &[T]> {
nrows: self.nrows(), nrows: self.nrows(),
ncols: self.ncols(), ncols: self.ncols(),
row_indices: self.row_indices().into(), row_indices: self.row_indices(),
col_indices: self.col_indices().into(), col_indices: self.col_indices(),
values: self.values().into(), values: self.values(),
} }
.serialize(serializer) .serialize(serializer)
} }
@ -318,13 +316,13 @@ where
where where
D: Deserializer<'de>, D: Deserializer<'de>,
{ {
let de = CooMatrixSerializationData::deserialize(deserializer)?; let de = CooMatrixSerializationData::<Vec<usize>, Vec<T>>::deserialize(deserializer)?;
CooMatrix::try_from_triplets( CooMatrix::try_from_triplets(
de.nrows, de.nrows,
de.ncols, de.ncols,
de.row_indices.into(), de.row_indices,
de.col_indices.into(), de.col_indices,
de.values.into(), de.values,
) )
.map_err(|e| de::Error::custom(e)) .map_err(|e| de::Error::custom(e))
} }

View File

@ -12,8 +12,6 @@ use nalgebra::Scalar;
use num_traits::One; use num_traits::One;
#[cfg(feature = "serde-serialize")] #[cfg(feature = "serde-serialize")]
use serde::{de, Deserialize, Deserializer, Serialize, Serializer}; use serde::{de, Deserialize, Deserializer, Serialize, Serializer};
#[cfg(feature = "serde-serialize")]
use std::borrow::Cow;
use std::slice::{Iter, IterMut}; use std::slice::{Iter, IterMut};
/// A CSC representation of a sparse matrix. /// A CSC representation of a sparse matrix.
@ -526,12 +524,12 @@ impl<T> CscMatrix<T> {
#[cfg(feature = "serde-serialize")] #[cfg(feature = "serde-serialize")]
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
struct CscMatrixSerializationData<'a, T: Clone> { struct CscMatrixSerializationData<Indices, Values> {
nrows: usize, nrows: usize,
ncols: usize, ncols: usize,
col_offsets: Cow<'a, [usize]>, col_offsets: Indices,
row_indices: Cow<'a, [usize]>, row_indices: Indices,
values: Cow<'a, [T]>, values: Values,
} }
#[cfg(feature = "serde-serialize")] #[cfg(feature = "serde-serialize")]
@ -543,12 +541,12 @@ where
where where
S: Serializer, S: Serializer,
{ {
CscMatrixSerializationData { CscMatrixSerializationData::<&[usize], &[T]> {
nrows: self.nrows(), nrows: self.nrows(),
ncols: self.ncols(), ncols: self.ncols(),
col_offsets: Cow::Borrowed(self.col_offsets()), col_offsets: self.col_offsets(),
row_indices: Cow::Borrowed(self.row_indices()), row_indices: self.row_indices(),
values: Cow::Borrowed(self.values()), values: self.values(),
} }
.serialize(serializer) .serialize(serializer)
} }
@ -563,13 +561,13 @@ where
where where
D: Deserializer<'de>, D: Deserializer<'de>,
{ {
let de = CscMatrixSerializationData::deserialize(deserializer)?; let de = CscMatrixSerializationData::<Vec<usize>, Vec<T>>::deserialize(deserializer)?;
CscMatrix::try_from_csc_data( CscMatrix::try_from_csc_data(
de.nrows, de.nrows,
de.ncols, de.ncols,
de.col_offsets.into(), de.col_offsets,
de.row_indices.into(), de.row_indices,
de.values.into(), de.values,
) )
.map_err(|e| de::Error::custom(e)) .map_err(|e| de::Error::custom(e))
} }

View File

@ -11,8 +11,6 @@ use nalgebra::Scalar;
use num_traits::One; use num_traits::One;
#[cfg(feature = "serde-serialize")] #[cfg(feature = "serde-serialize")]
use serde::{de, Deserialize, Deserializer, Serialize, Serializer}; use serde::{de, Deserialize, Deserializer, Serialize, Serializer};
#[cfg(feature = "serde-serialize")]
use std::borrow::Cow;
use std::iter::FromIterator; use std::iter::FromIterator;
use std::slice::{Iter, IterMut}; use std::slice::{Iter, IterMut};
@ -597,12 +595,12 @@ impl<T> CsrMatrix<T> {
#[cfg(feature = "serde-serialize")] #[cfg(feature = "serde-serialize")]
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
struct CsrMatrixSerializationData<'a, T: Clone> { struct CsrMatrixSerializationData<Indices, Values> {
nrows: usize, nrows: usize,
ncols: usize, ncols: usize,
row_offsets: Cow<'a, [usize]>, row_offsets: Indices,
col_indices: Cow<'a, [usize]>, col_indices: Indices,
values: Cow<'a, [T]>, values: Values,
} }
#[cfg(feature = "serde-serialize")] #[cfg(feature = "serde-serialize")]
@ -614,12 +612,12 @@ where
where where
S: Serializer, S: Serializer,
{ {
CsrMatrixSerializationData { CsrMatrixSerializationData::<&[usize], &[T]> {
nrows: self.nrows(), nrows: self.nrows(),
ncols: self.ncols(), ncols: self.ncols(),
row_offsets: Cow::Borrowed(self.row_offsets()), row_offsets: self.row_offsets(),
col_indices: Cow::Borrowed(self.col_indices()), col_indices: self.col_indices(),
values: Cow::Borrowed(self.values()), values: self.values(),
} }
.serialize(serializer) .serialize(serializer)
} }
@ -634,13 +632,13 @@ where
where where
D: Deserializer<'de>, D: Deserializer<'de>,
{ {
let de = CsrMatrixSerializationData::deserialize(deserializer)?; let de = CsrMatrixSerializationData::<Vec<usize>, Vec<T>>::deserialize(deserializer)?;
CsrMatrix::try_from_csr_data( CsrMatrix::try_from_csr_data(
de.nrows, de.nrows,
de.ncols, de.ncols,
de.row_offsets.into(), de.row_offsets,
de.col_indices.into(), de.col_indices,
de.values.into(), de.values,
) )
.map_err(|e| de::Error::custom(e)) .map_err(|e| de::Error::custom(e))
} }

View File

@ -6,8 +6,6 @@ use std::fmt;
#[cfg(feature = "serde-serialize")] #[cfg(feature = "serde-serialize")]
use serde::{de, Deserialize, Deserializer, Serialize, Serializer}; use serde::{de, Deserialize, Deserializer, Serialize, Serializer};
#[cfg(feature = "serde-serialize")]
use std::borrow::Cow;
/// A representation of the sparsity pattern of a CSR or CSC matrix. /// A representation of the sparsity pattern of a CSR or CSC matrix.
/// ///
@ -296,11 +294,11 @@ pub enum SparsityPatternFormatError {
#[cfg(feature = "serde-serialize")] #[cfg(feature = "serde-serialize")]
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
struct SparsityPatternSerializationData<'a> { struct SparsityPatternSerializationData<Indices> {
major_dim: usize, major_dim: usize,
minor_dim: usize, minor_dim: usize,
major_offsets: Cow<'a, [usize]>, major_offsets: Indices,
minor_indices: Cow<'a, [usize]>, minor_indices: Indices,
} }
#[cfg(feature = "serde-serialize")] #[cfg(feature = "serde-serialize")]
@ -309,11 +307,11 @@ impl Serialize for SparsityPattern {
where where
S: Serializer, S: Serializer,
{ {
SparsityPatternSerializationData { SparsityPatternSerializationData::<&[usize]> {
major_dim: self.major_dim(), major_dim: self.major_dim(),
minor_dim: self.minor_dim(), minor_dim: self.minor_dim(),
major_offsets: Cow::Borrowed(self.major_offsets()), major_offsets: self.major_offsets(),
minor_indices: Cow::Borrowed(self.minor_indices()), minor_indices: self.minor_indices(),
} }
.serialize(serializer) .serialize(serializer)
} }
@ -325,12 +323,12 @@ impl<'de> Deserialize<'de> for SparsityPattern {
where where
D: Deserializer<'de>, D: Deserializer<'de>,
{ {
let de = SparsityPatternSerializationData::deserialize(deserializer)?; let de = SparsityPatternSerializationData::<Vec<usize>>::deserialize(deserializer)?;
SparsityPattern::try_from_offsets_and_indices( SparsityPattern::try_from_offsets_and_indices(
de.major_dim, de.major_dim,
de.minor_dim, de.minor_dim,
de.major_offsets.into(), de.major_offsets,
de.minor_indices.into(), de.minor_indices,
) )
.map_err(|e| de::Error::custom(e)) .map_err(|e| de::Error::custom(e))
} }