Improve CsMatrix multiplication performance.

This commit is contained in:
sebcrozet 2018-11-06 18:27:43 +01:00
parent 538e18b3e9
commit 383a18f083

View File

@ -7,7 +7,7 @@ use std::slice;
use allocator::Allocator;
use constraint::{AreMultipliable, DimEq, SameNumberOfRows, ShapeConstraint};
use sparse::{CsMatrix, CsStorage, CsVector};
use sparse::{CsMatrix, CsStorage, CsStorageMut, CsVector};
use storage::{Storage, StorageMut};
use {DefaultAllocator, Dim, Matrix, MatrixMN, Real, Scalar, Vector, VectorN, U1};
@ -150,8 +150,7 @@ where
);
let mut res = CsMatrix::new_uninitialized_generic(nrows1, ncols2, self.len() + rhs.len());
let mut timestamps = VectorN::zeros_generic(nrows1, U1);
let mut workspace = unsafe { VectorN::new_uninitialized_generic(nrows1, U1) };
let mut workspace = VectorN::<N, R1>::zeros_generic(nrows1, U1);
let mut nz = 0;
for j in 0..ncols2.value() {
@ -160,24 +159,19 @@ where
res.data.i.resize(new_size_bound, 0);
res.data.vals.resize(new_size_bound, N::zero());
for (i, val) in rhs.data.column_entries(j) {
nz = self.scatter(
i,
val,
timestamps.as_mut_slice(),
j + 1,
workspace.as_mut_slice(),
nz,
&mut res,
);
for (i, beta) in rhs.data.column_entries(j) {
for (k, val) in self.data.column_entries(i) {
workspace[k] += val * beta;
}
}
// Keep the output sorted.
let range = res.data.p[j]..nz;
res.data.i[range.clone()].sort();
for p in range {
res.data.vals[p] = workspace[res.data.i[p]]
for (i, val) in workspace.as_mut_slice().iter_mut().enumerate() {
if !val.is_zero() {
res.data.i[nz] = i;
res.data.vals[nz] = *val;
*val = N::zero();
nz += 1;
}
}
}
@ -257,3 +251,21 @@ where
res
}
}
/// Scalar multiplication `matrix * rhs` for `CsMatrix`.
///
/// The matrix is taken by value, every value yielded by `values_mut()` is
/// multiplied in place by `rhs`, and the same matrix is returned. This impl
/// only writes through `values_mut()`; it builds no new matrix.
///
/// NOTE(review): presumably `values_mut()` visits exactly the explicitly
/// stored entries, so multiplying by zero would leave explicit zeros in the
/// storage rather than removing them — confirm against `CsMatrix`.
impl<N, R, C, S> Mul<N> for CsMatrix<N, R, C, S>
where
    N: Scalar + ClosedAdd + ClosedMul + Zero,
    R: Dim,
    C: Dim,
    S: CsStorageMut<N, R, C>,
{
    type Output = Self;

    /// Scales `self` by `rhs` in place and returns it.
    fn mul(mut self, rhs: N) -> Self {
        for e in self.values_mut() {
            *e *= rhs
        }

        self
    }
}