Improve CsMatrix multiplaction performances.
This commit is contained in:
parent
538e18b3e9
commit
383a18f083
@ -7,7 +7,7 @@ use std::slice;
|
||||
|
||||
use allocator::Allocator;
|
||||
use constraint::{AreMultipliable, DimEq, SameNumberOfRows, ShapeConstraint};
|
||||
use sparse::{CsMatrix, CsStorage, CsVector};
|
||||
use sparse::{CsMatrix, CsStorage, CsStorageMut, CsVector};
|
||||
use storage::{Storage, StorageMut};
|
||||
use {DefaultAllocator, Dim, Matrix, MatrixMN, Real, Scalar, Vector, VectorN, U1};
|
||||
|
||||
@ -150,8 +150,7 @@ where
|
||||
);
|
||||
|
||||
let mut res = CsMatrix::new_uninitialized_generic(nrows1, ncols2, self.len() + rhs.len());
|
||||
let mut timestamps = VectorN::zeros_generic(nrows1, U1);
|
||||
let mut workspace = unsafe { VectorN::new_uninitialized_generic(nrows1, U1) };
|
||||
let mut workspace = VectorN::<N, R1>::zeros_generic(nrows1, U1);
|
||||
let mut nz = 0;
|
||||
|
||||
for j in 0..ncols2.value() {
|
||||
@ -160,24 +159,19 @@ where
|
||||
res.data.i.resize(new_size_bound, 0);
|
||||
res.data.vals.resize(new_size_bound, N::zero());
|
||||
|
||||
for (i, val) in rhs.data.column_entries(j) {
|
||||
nz = self.scatter(
|
||||
i,
|
||||
val,
|
||||
timestamps.as_mut_slice(),
|
||||
j + 1,
|
||||
workspace.as_mut_slice(),
|
||||
nz,
|
||||
&mut res,
|
||||
);
|
||||
for (i, beta) in rhs.data.column_entries(j) {
|
||||
for (k, val) in self.data.column_entries(i) {
|
||||
workspace[k] += val * beta;
|
||||
}
|
||||
}
|
||||
|
||||
// Keep the output sorted.
|
||||
let range = res.data.p[j]..nz;
|
||||
res.data.i[range.clone()].sort();
|
||||
|
||||
for p in range {
|
||||
res.data.vals[p] = workspace[res.data.i[p]]
|
||||
for (i, val) in workspace.as_mut_slice().iter_mut().enumerate() {
|
||||
if !val.is_zero() {
|
||||
res.data.i[nz] = i;
|
||||
res.data.vals[nz] = *val;
|
||||
*val = N::zero();
|
||||
nz += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -257,3 +251,21 @@ where
|
||||
res
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, 'b, N, R, C, S> Mul<N> for CsMatrix<N, R, C, S>
|
||||
where
|
||||
N: Scalar + ClosedAdd + ClosedMul + Zero,
|
||||
R: Dim,
|
||||
C: Dim,
|
||||
S: CsStorageMut<N, R, C>,
|
||||
{
|
||||
type Output = Self;
|
||||
|
||||
fn mul(mut self, rhs: N) -> Self {
|
||||
for e in self.values_mut() {
|
||||
*e *= rhs
|
||||
}
|
||||
|
||||
self
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user