Improve CsMatrix multiplaction performances.
This commit is contained in:
parent
538e18b3e9
commit
383a18f083
|
@ -7,7 +7,7 @@ use std::slice;
|
||||||
|
|
||||||
use allocator::Allocator;
|
use allocator::Allocator;
|
||||||
use constraint::{AreMultipliable, DimEq, SameNumberOfRows, ShapeConstraint};
|
use constraint::{AreMultipliable, DimEq, SameNumberOfRows, ShapeConstraint};
|
||||||
use sparse::{CsMatrix, CsStorage, CsVector};
|
use sparse::{CsMatrix, CsStorage, CsStorageMut, CsVector};
|
||||||
use storage::{Storage, StorageMut};
|
use storage::{Storage, StorageMut};
|
||||||
use {DefaultAllocator, Dim, Matrix, MatrixMN, Real, Scalar, Vector, VectorN, U1};
|
use {DefaultAllocator, Dim, Matrix, MatrixMN, Real, Scalar, Vector, VectorN, U1};
|
||||||
|
|
||||||
|
@ -150,8 +150,7 @@ where
|
||||||
);
|
);
|
||||||
|
|
||||||
let mut res = CsMatrix::new_uninitialized_generic(nrows1, ncols2, self.len() + rhs.len());
|
let mut res = CsMatrix::new_uninitialized_generic(nrows1, ncols2, self.len() + rhs.len());
|
||||||
let mut timestamps = VectorN::zeros_generic(nrows1, U1);
|
let mut workspace = VectorN::<N, R1>::zeros_generic(nrows1, U1);
|
||||||
let mut workspace = unsafe { VectorN::new_uninitialized_generic(nrows1, U1) };
|
|
||||||
let mut nz = 0;
|
let mut nz = 0;
|
||||||
|
|
||||||
for j in 0..ncols2.value() {
|
for j in 0..ncols2.value() {
|
||||||
|
@ -160,24 +159,19 @@ where
|
||||||
res.data.i.resize(new_size_bound, 0);
|
res.data.i.resize(new_size_bound, 0);
|
||||||
res.data.vals.resize(new_size_bound, N::zero());
|
res.data.vals.resize(new_size_bound, N::zero());
|
||||||
|
|
||||||
for (i, val) in rhs.data.column_entries(j) {
|
for (i, beta) in rhs.data.column_entries(j) {
|
||||||
nz = self.scatter(
|
for (k, val) in self.data.column_entries(i) {
|
||||||
i,
|
workspace[k] += val * beta;
|
||||||
val,
|
}
|
||||||
timestamps.as_mut_slice(),
|
|
||||||
j + 1,
|
|
||||||
workspace.as_mut_slice(),
|
|
||||||
nz,
|
|
||||||
&mut res,
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Keep the output sorted.
|
for (i, val) in workspace.as_mut_slice().iter_mut().enumerate() {
|
||||||
let range = res.data.p[j]..nz;
|
if !val.is_zero() {
|
||||||
res.data.i[range.clone()].sort();
|
res.data.i[nz] = i;
|
||||||
|
res.data.vals[nz] = *val;
|
||||||
for p in range {
|
*val = N::zero();
|
||||||
res.data.vals[p] = workspace[res.data.i[p]]
|
nz += 1;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -257,3 +251,21 @@ where
|
||||||
res
|
res
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<'a, 'b, N, R, C, S> Mul<N> for CsMatrix<N, R, C, S>
|
||||||
|
where
|
||||||
|
N: Scalar + ClosedAdd + ClosedMul + Zero,
|
||||||
|
R: Dim,
|
||||||
|
C: Dim,
|
||||||
|
S: CsStorageMut<N, R, C>,
|
||||||
|
{
|
||||||
|
type Output = Self;
|
||||||
|
|
||||||
|
fn mul(mut self, rhs: N) -> Self {
|
||||||
|
for e in self.values_mut() {
|
||||||
|
*e *= rhs
|
||||||
|
}
|
||||||
|
|
||||||
|
self
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue