Move serialization helper structs into trait impls

This commit is contained in:
Fabian Loeschner 2021-11-08 11:10:58 +01:00
parent f9aca24b15
commit 18b694dad2

View File

@ -599,15 +599,6 @@ impl<T> CsrMatrix<T> {
}
}
#[cfg_attr(feature = "serde-serialize", derive(Serialize))]
struct CsrMatrixSerializationHelper<'a, T> {
nrows: usize,
ncols: usize,
row_offsets: &'a [usize],
col_indices: &'a [usize],
values: &'a [T],
}
#[cfg(feature = "serde-serialize")]
impl<T> Serialize for CsrMatrix<T>
where
@ -617,7 +608,16 @@ where
where
S: Serializer,
{
CsrMatrixSerializationHelper {
#[derive(Serialize)]
struct CsrMatrixSerializationData<'a, T> {
nrows: usize,
ncols: usize,
row_offsets: &'a [usize],
col_indices: &'a [usize],
values: &'a [T],
}
CsrMatrixSerializationData {
nrows: self.nrows(),
ncols: self.ncols(),
row_offsets: self.row_offsets(),
@ -628,15 +628,6 @@ where
}
}
#[cfg_attr(feature = "serde-serialize", derive(Deserialize))]
struct CsrMatrixDeserializationHelper<T> {
nrows: usize,
ncols: usize,
row_offsets: Vec<usize>,
col_indices: Vec<usize>,
values: Vec<T>,
}
#[cfg(feature = "serde-serialize")]
impl<'de, T> Deserialize<'de> for CsrMatrix<T>
where
@ -646,14 +637,17 @@ where
where
D: Deserializer<'de>,
{
let CsrMatrixDeserializationHelper {
nrows,
ncols,
row_offsets,
col_indices,
values,
} = CsrMatrixDeserializationHelper::deserialize(deserializer)?;
CsrMatrix::try_from_csr_data(nrows, ncols, row_offsets, col_indices, values)
#[derive(Deserialize)]
struct CsrMatrixDeserializationData<T> {
nrows: usize,
ncols: usize,
row_offsets: Vec<usize>,
col_indices: Vec<usize>,
values: Vec<T>,
}
let de = CsrMatrixDeserializationData::deserialize(deserializer)?;
CsrMatrix::try_from_csr_data(de.nrows, de.ncols, de.row_offsets, de.col_indices, de.values)
.map(|m| m.into())
// TODO: More specific error
.map_err(|_e| de::Error::invalid_value(de::Unexpected::Other("invalid CSR matrix"), &"a valid CSR matrix"))