Move serialization helper structs into trait impls

This commit is contained in:
Fabian Loeschner 2021-11-08 11:10:58 +01:00
parent f9aca24b15
commit 18b694dad2

View File

@ -599,15 +599,6 @@ impl<T> CsrMatrix<T> {
} }
} }
#[cfg_attr(feature = "serde-serialize", derive(Serialize))]
struct CsrMatrixSerializationHelper<'a, T> {
nrows: usize,
ncols: usize,
row_offsets: &'a [usize],
col_indices: &'a [usize],
values: &'a [T],
}
#[cfg(feature = "serde-serialize")] #[cfg(feature = "serde-serialize")]
impl<T> Serialize for CsrMatrix<T> impl<T> Serialize for CsrMatrix<T>
where where
@ -617,7 +608,16 @@ where
where where
S: Serializer, S: Serializer,
{ {
CsrMatrixSerializationHelper { #[derive(Serialize)]
struct CsrMatrixSerializationData<'a, T> {
nrows: usize,
ncols: usize,
row_offsets: &'a [usize],
col_indices: &'a [usize],
values: &'a [T],
}
CsrMatrixSerializationData {
nrows: self.nrows(), nrows: self.nrows(),
ncols: self.ncols(), ncols: self.ncols(),
row_offsets: self.row_offsets(), row_offsets: self.row_offsets(),
@ -628,15 +628,6 @@ where
} }
} }
#[cfg_attr(feature = "serde-serialize", derive(Deserialize))]
struct CsrMatrixDeserializationHelper<T> {
nrows: usize,
ncols: usize,
row_offsets: Vec<usize>,
col_indices: Vec<usize>,
values: Vec<T>,
}
#[cfg(feature = "serde-serialize")] #[cfg(feature = "serde-serialize")]
impl<'de, T> Deserialize<'de> for CsrMatrix<T> impl<'de, T> Deserialize<'de> for CsrMatrix<T>
where where
@ -646,14 +637,17 @@ where
where where
D: Deserializer<'de>, D: Deserializer<'de>,
{ {
let CsrMatrixDeserializationHelper { #[derive(Deserialize)]
nrows, struct CsrMatrixDeserializationData<T> {
ncols, nrows: usize,
row_offsets, ncols: usize,
col_indices, row_offsets: Vec<usize>,
values, col_indices: Vec<usize>,
} = CsrMatrixDeserializationHelper::deserialize(deserializer)?; values: Vec<T>,
CsrMatrix::try_from_csr_data(nrows, ncols, row_offsets, col_indices, values) }
let de = CsrMatrixDeserializationData::deserialize(deserializer)?;
CsrMatrix::try_from_csr_data(de.nrows, de.ncols, de.row_offsets, de.col_indices, de.values)
.map(|m| m.into()) .map(|m| m.into())
// TODO: More specific error // TODO: More specific error
.map_err(|_e| de::Error::invalid_value(de::Unexpected::Other("invalid CSR matrix"), &"a valid CSR matrix")) .map_err(|_e| de::Error::invalid_value(de::Unexpected::Other("invalid CSR matrix"), &"a valid CSR matrix"))