From 383a18f08362d02ebd8a89bbf18cede29c7418c0 Mon Sep 17 00:00:00 2001 From: sebcrozet Date: Tue, 6 Nov 2018 18:27:43 +0100 Subject: [PATCH] Improve CsMatrix multiplaction performances. --- src/sparse/cs_matrix_ops.rs | 50 +++++++++++++++++++++++-------------- 1 file changed, 31 insertions(+), 19 deletions(-) diff --git a/src/sparse/cs_matrix_ops.rs b/src/sparse/cs_matrix_ops.rs index 34a68f9b..bf1dccdb 100644 --- a/src/sparse/cs_matrix_ops.rs +++ b/src/sparse/cs_matrix_ops.rs @@ -7,7 +7,7 @@ use std::slice; use allocator::Allocator; use constraint::{AreMultipliable, DimEq, SameNumberOfRows, ShapeConstraint}; -use sparse::{CsMatrix, CsStorage, CsVector}; +use sparse::{CsMatrix, CsStorage, CsStorageMut, CsVector}; use storage::{Storage, StorageMut}; use {DefaultAllocator, Dim, Matrix, MatrixMN, Real, Scalar, Vector, VectorN, U1}; @@ -150,8 +150,7 @@ where ); let mut res = CsMatrix::new_uninitialized_generic(nrows1, ncols2, self.len() + rhs.len()); - let mut timestamps = VectorN::zeros_generic(nrows1, U1); - let mut workspace = unsafe { VectorN::new_uninitialized_generic(nrows1, U1) }; + let mut workspace = VectorN::::zeros_generic(nrows1, U1); let mut nz = 0; for j in 0..ncols2.value() { @@ -160,24 +159,19 @@ where res.data.i.resize(new_size_bound, 0); res.data.vals.resize(new_size_bound, N::zero()); - for (i, val) in rhs.data.column_entries(j) { - nz = self.scatter( - i, - val, - timestamps.as_mut_slice(), - j + 1, - workspace.as_mut_slice(), - nz, - &mut res, - ); + for (i, beta) in rhs.data.column_entries(j) { + for (k, val) in self.data.column_entries(i) { + workspace[k] += val * beta; + } } - // Keep the output sorted. - let range = res.data.p[j]..nz; - res.data.i[range.clone()].sort(); - - for p in range { - res.data.vals[p] = workspace[res.data.i[p]] + for (i, val) in workspace.as_mut_slice().iter_mut().enumerate() { + if !val.is_zero() { + res.data.i[nz] = i; + res.data.vals[nz] = *val; + *val = N::zero(); + nz += 1; + } } } @@ -257,3 +251,21 @@ where res } } + +impl<'a, 'b, N, R, C, S> Mul for CsMatrix +where + N: Scalar + ClosedAdd + ClosedMul + Zero, + R: Dim, + C: Dim, + S: CsStorageMut, +{ + type Output = Self; + + fn mul(mut self, rhs: N) -> Self { + for e in self.values_mut() { + *e *= rhs + } + + self + } +}