Improve CsMatrix multiplaction performances.

This commit is contained in:
sebcrozet 2018-11-06 18:27:43 +01:00
parent 538e18b3e9
commit 383a18f083

View File

@ -7,7 +7,7 @@ use std::slice;
use allocator::Allocator; use allocator::Allocator;
use constraint::{AreMultipliable, DimEq, SameNumberOfRows, ShapeConstraint}; use constraint::{AreMultipliable, DimEq, SameNumberOfRows, ShapeConstraint};
use sparse::{CsMatrix, CsStorage, CsVector}; use sparse::{CsMatrix, CsStorage, CsStorageMut, CsVector};
use storage::{Storage, StorageMut}; use storage::{Storage, StorageMut};
use {DefaultAllocator, Dim, Matrix, MatrixMN, Real, Scalar, Vector, VectorN, U1}; use {DefaultAllocator, Dim, Matrix, MatrixMN, Real, Scalar, Vector, VectorN, U1};
@ -150,8 +150,7 @@ where
); );
let mut res = CsMatrix::new_uninitialized_generic(nrows1, ncols2, self.len() + rhs.len()); let mut res = CsMatrix::new_uninitialized_generic(nrows1, ncols2, self.len() + rhs.len());
let mut timestamps = VectorN::zeros_generic(nrows1, U1); let mut workspace = VectorN::<N, R1>::zeros_generic(nrows1, U1);
let mut workspace = unsafe { VectorN::new_uninitialized_generic(nrows1, U1) };
let mut nz = 0; let mut nz = 0;
for j in 0..ncols2.value() { for j in 0..ncols2.value() {
@ -160,24 +159,19 @@ where
res.data.i.resize(new_size_bound, 0); res.data.i.resize(new_size_bound, 0);
res.data.vals.resize(new_size_bound, N::zero()); res.data.vals.resize(new_size_bound, N::zero());
for (i, val) in rhs.data.column_entries(j) { for (i, beta) in rhs.data.column_entries(j) {
nz = self.scatter( for (k, val) in self.data.column_entries(i) {
i, workspace[k] += val * beta;
val, }
timestamps.as_mut_slice(),
j + 1,
workspace.as_mut_slice(),
nz,
&mut res,
);
} }
// Keep the output sorted. for (i, val) in workspace.as_mut_slice().iter_mut().enumerate() {
let range = res.data.p[j]..nz; if !val.is_zero() {
res.data.i[range.clone()].sort(); res.data.i[nz] = i;
res.data.vals[nz] = *val;
for p in range { *val = N::zero();
res.data.vals[p] = workspace[res.data.i[p]] nz += 1;
}
} }
} }
@ -257,3 +251,21 @@ where
res res
} }
} }
impl<'a, 'b, N, R, C, S> Mul<N> for CsMatrix<N, R, C, S>
where
N: Scalar + ClosedAdd + ClosedMul + Zero,
R: Dim,
C: Dim,
S: CsStorageMut<N, R, C>,
{
type Output = Self;
fn mul(mut self, rhs: N) -> Self {
for e in self.values_mut() {
*e *= rhs
}
self
}
}