nalgebra/src/sparse/cs_matrix_ops.rs
2021-04-11 13:57:54 +02:00

306 lines
9.5 KiB
Rust
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

use num::{One, Zero};
use simba::scalar::{ClosedAdd, ClosedMul};
use std::ops::{Add, Mul};
use crate::allocator::Allocator;
use crate::constraint::{AreMultipliable, DimEq, ShapeConstraint};
use crate::sparse::{CsMatrix, CsStorage, CsStorageMut, CsVector};
use crate::storage::StorageMut;
use crate::{Const, DefaultAllocator, Dim, OVector, Scalar, Vector};
impl<T: Scalar, R: Dim, C: Dim, S: CsStorage<T, R, C>> CsMatrix<T, R, C, S> {
    /// Scatters `beta` times the `j`-th column of `self` into the dense `workspace`.
    ///
    /// For every stored entry `(row, value)` of column `j`:
    /// - if `timestamps[row] < timestamp`, the row is stamped with `timestamp`, its index is
    ///   recorded in `res.data.i[nz]`, `nz` is incremented and `workspace[row]` is
    ///   overwritten with `value * beta`;
    /// - otherwise `value * beta` is added to `workspace[row]`.
    ///
    /// Only the row-index buffer of `res` is written here; its values are not touched.
    /// Returns the updated `nz`.
    fn scatter<R2: Dim, C2: Dim>(
        &self,
        j: usize,
        beta: T,
        timestamps: &mut [usize],
        timestamp: usize,
        workspace: &mut [T],
        mut nz: usize,
        res: &mut CsMatrix<T, R2, C2>,
    ) -> usize
    where
        T: ClosedAdd + ClosedMul,
        DefaultAllocator: Allocator<usize, C2>,
    {
        for (row, entry) in self.data.column_entries(j) {
            let scaled = entry * beta.inlined_clone();

            // Row already stamped with this timestamp: accumulate into the workspace.
            if timestamps[row] >= timestamp {
                workspace[row] += scaled;
                continue;
            }

            // First contribution to this row under the current timestamp.
            timestamps[row] = timestamp;
            res.data.i[nz] = row;
            nz += 1;
            workspace[row] = scaled;
        }

        nz
    }
}
/*
impl<T: Scalar, R, S> CsVector<T, R, S> {
pub fn axpy(&mut self, alpha: T, x: CsVector<T, R, S>, beta: T) {
// First, compute the number of non-zero entries.
let mut nnzero = 0;
// Allocate a size large enough.
self.data.set_column_len(0, nnzero);
// Fill with the axpy.
let mut i = self.len();
let mut j = x.len();
let mut k = nnzero - 1;
let mut rid1 = self.data.row_index(0, i - 1);
let mut rid2 = x.data.row_index(0, j - 1);
while k > 0 {
if rid1 == rid2 {
self.data.set_row_index(0, k, rid1);
self[k] = alpha * x[j] + beta * self[k];
i -= 1;
j -= 1;
} else if rid1 < rid2 {
self.data.set_row_index(0, k, rid1);
self[k] = beta * self[i];
i -= 1;
} else {
self.data.set_row_index(0, k, rid2);
self[k] = alpha * x[j];
j -= 1;
}
k -= 1;
}
}
}
*/
impl<T: Scalar + Zero + ClosedAdd + ClosedMul, D: Dim, S: StorageMut<T, D>> Vector<T, D, S> {
    /// Perform a sparse axpy operation: `self = alpha * x + beta * self` operation.
    ///
    /// If `beta` is zero, `self` is overwritten with `alpha * x`: the previous content of
    /// `self` is discarded, including the components that are not stored in `x`.
    ///
    /// # Panics
    ///
    /// Panics if `self` and `x` do not have the same number of rows.
    pub fn axpy_cs<D2: Dim, S2>(&mut self, alpha: T, x: &CsVector<T, D2, S2>, beta: T)
    where
        S2: CsStorage<T, D2>,
        ShapeConstraint: DimEq<D, D2>,
    {
        // `DimEq` cannot rule out a mismatch between dynamically-sized dimensions, and the
        // unchecked accesses below index `self` with the row indices stored in `x`.
        assert_eq!(
            self.len(),
            x.data.shape().0.value(),
            "Mismatched dimensions for sparse axpy."
        );

        if beta.is_zero() {
            // Components not present on `x` must end up equal to zero instead of keeping
            // their previous value.
            self.fill(T::zero());

            for i in 0..x.len() {
                // SAFETY: `i` is smaller than the number of entries stored in `x`.
                // NOTE(review): this assumes every row index stored in `x` is smaller than
                // the number of rows of `x`, which equals `self.len()` per the assertion above.
                unsafe {
                    let k = x.data.row_index_unchecked(i);
                    let y = self.vget_unchecked_mut(k);
                    *y = alpha.inlined_clone() * x.data.get_value_unchecked(i).inlined_clone();
                }
            }
        } else {
            // Needed to be sure even components not present on `x` are multiplied.
            *self *= beta.inlined_clone();

            for i in 0..x.len() {
                // SAFETY: same reasoning as in the `beta == 0` branch above.
                unsafe {
                    let k = x.data.row_index_unchecked(i);
                    let y = self.vget_unchecked_mut(k);
                    *y += alpha.inlined_clone() * x.data.get_value_unchecked(i).inlined_clone();
                }
            }
        }
    }
    /*
    pub fn gemv_sparse<R2: Dim, C2: Dim, S2>(&mut self, alpha: T, a: &CsMatrix<T, R2, C2, S2>, x: &DVector<T>, beta: T)
    where
    S2: CsStorage<T, R2, C2> {
    let col2 = a.column(0);
    let val = unsafe { *x.vget_unchecked(0) };
    self.axpy_sparse(alpha * val, &col2, beta);
    for j in 1..ncols2 {
    let col2 = a.column(j);
    let val = unsafe { *x.vget_unchecked(j) };
    self.axpy_sparse(alpha * val, &col2, T::one());
    }
    }
    */
}
impl<'a, 'b, T, R1, R2, C1, C2, S1, S2> Mul<&'b CsMatrix<T, R2, C2, S2>>
    for &'a CsMatrix<T, R1, C1, S1>
where
    T: Scalar + ClosedAdd + ClosedMul + Zero,
    R1: Dim,
    C1: Dim,
    R2: Dim,
    C2: Dim,
    S1: CsStorage<T, R1, C1>,
    S2: CsStorage<T, R2, C2>,
    ShapeConstraint: AreMultipliable<R1, C1, R2, C2>,
    DefaultAllocator: Allocator<usize, C2> + Allocator<usize, R1> + Allocator<T, R1>,
{
    type Output = CsMatrix<T, R1, C2>;

    /// Computes the sparse matrix product `self * rhs`.
    ///
    /// The result is assembled one column at a time: the contributions to column `j` are
    /// accumulated into a dense workspace with one slot per row of `self`, then every
    /// non-zero slot is appended to the result (in increasing row order) and reset to zero.
    /// Slots whose accumulated value is zero are not stored.
    ///
    /// # Panics
    ///
    /// Panics if the number of columns of `self` differs from the number of rows of `rhs`.
    fn mul(self, rhs: &'b CsMatrix<T, R2, C2, S2>) -> Self::Output {
        let (nrows1, ncols1) = self.data.shape();
        let (nrows2, ncols2) = rhs.data.shape();
        assert_eq!(
            ncols1.value(),
            nrows2.value(),
            "Mismatched dimensions for matrix multiplication."
        );

        let mut res = CsMatrix::new_uninitialized_generic(nrows1, ncols2, self.len() + rhs.len());
        let mut workspace = OVector::<T, R1>::zeros_generic(nrows1, Const::<1>);
        let mut nz = 0;

        for j in 0..ncols2.value() {
            res.data.p[j] = nz;

            // A single output column holds at most `nrows1` entries: make room for them.
            let capacity = nz + nrows1.value();
            res.data.i.resize(capacity, 0);
            res.data.vals.resize(capacity, T::zero());

            // Accumulate `rhs[inner, j] * self[.., inner]` into the dense workspace.
            for (inner, rhs_val) in rhs.data.column_entries(j) {
                for (row, lhs_val) in self.data.column_entries(inner) {
                    workspace[row] += lhs_val.inlined_clone() * rhs_val.inlined_clone();
                }
            }

            // Gather the non-zero slots into the result and clear them for the next column.
            for (row, acc) in workspace.as_mut_slice().iter_mut().enumerate() {
                if acc.is_zero() {
                    continue;
                }

                res.data.i[nz] = row;
                res.data.vals[nz] = acc.inlined_clone();
                *acc = T::zero();
                nz += 1;
            }
        }

        // NOTE: the following has a lower complexity, but is slower in many cases, likely because
        // of branching inside of the inner loop.
        //
        // let mut res = CsMatrix::new_uninitialized_generic(nrows1, ncols2, self.len() + rhs.len());
        // let mut timestamps = OVector::zeros_generic(nrows1, Const::<1>);
        // let mut workspace = unsafe { OVector::new_uninitialized_generic(nrows1, Const::<1>) };
        // let mut nz = 0;
        //
        // for j in 0..ncols2.value() {
        //     res.data.p[j] = nz;
        //     let new_size_bound = nz + nrows1.value();
        //     res.data.i.resize(new_size_bound, 0);
        //     res.data.vals.resize(new_size_bound, T::zero());
        //
        //     for (i, val) in rhs.data.column_entries(j) {
        //         nz = self.scatter(
        //             i,
        //             val,
        //             timestamps.as_mut_slice(),
        //             j + 1,
        //             workspace.as_mut_slice(),
        //             nz,
        //             &mut res,
        //         );
        //     }
        //
        //     // Keep the output sorted.
        //     let range = res.data.p[j]..nz;
        //     res.data.i[range.clone()].sort();
        //
        //     for p in range {
        //         res.data.vals[p] = workspace[res.data.i[p]]
        //     }
        // }

        // Drop the unused tail of both buffers.
        res.data.i.truncate(nz);
        res.data.i.shrink_to_fit();
        res.data.vals.truncate(nz);
        res.data.vals.shrink_to_fit();
        res
    }
}
impl<'a, 'b, T, R1, R2, C1, C2, S1, S2> Add<&'b CsMatrix<T, R2, C2, S2>>
    for &'a CsMatrix<T, R1, C1, S1>
where
    T: Scalar + ClosedAdd + ClosedMul + One,
    R1: Dim,
    C1: Dim,
    R2: Dim,
    C2: Dim,
    S1: CsStorage<T, R1, C1>,
    S2: CsStorage<T, R2, C2>,
    ShapeConstraint: DimEq<R1, R2> + DimEq<C1, C2>,
    DefaultAllocator: Allocator<usize, C2> + Allocator<usize, R1> + Allocator<T, R1>,
{
    type Output = CsMatrix<T, R1, C2>;

    /// Computes the sparse matrix sum `self + rhs`.
    ///
    /// For each column, the entries of `self` and then of `rhs` are scattered (with a unit
    /// coefficient) into a dense workspace. The row indices collected for that column are
    /// then sorted and the matching workspace values are copied into the result.
    ///
    /// # Panics
    ///
    /// Panics if `self` and `rhs` do not have the same shape.
    fn add(self, rhs: &'b CsMatrix<T, R2, C2, S2>) -> Self::Output {
        let (nrows1, ncols1) = self.data.shape();
        let (nrows2, ncols2) = rhs.data.shape();
        assert_eq!(
            (nrows1.value(), ncols1.value()),
            (nrows2.value(), ncols2.value()),
            "Mismatched dimensions for matrix sum."
        );

        let mut res = CsMatrix::new_uninitialized_generic(nrows1, ncols2, self.len() + rhs.len());
        let mut timestamps = OVector::zeros_generic(nrows1, Const::<1>);
        // The workspace is only read at rows recorded for the current column, and `scatter`
        // writes a workspace slot whenever it records its row.
        let mut workspace =
            unsafe { crate::unimplemented_or_uninitialized_generic!(nrows1, Const::<1>) };
        let mut nz = 0;

        for j in 0..ncols2.value() {
            let column_start = nz;
            res.data.p[j] = column_start;

            // Timestamps start at zero, so column `j` is stamped with `j + 1`.
            let stamp = j + 1;

            nz = self.scatter(
                j,
                T::one(),
                timestamps.as_mut_slice(),
                stamp,
                workspace.as_mut_slice(),
                nz,
                &mut res,
            );
            nz = rhs.scatter(
                j,
                T::one(),
                timestamps.as_mut_slice(),
                stamp,
                workspace.as_mut_slice(),
                nz,
                &mut res,
            );

            // Keep the output sorted.
            res.data.i[column_start..nz].sort();

            for p in column_start..nz {
                let row = res.data.i[p];
                res.data.vals[p] = workspace[row].inlined_clone()
            }
        }

        // Drop the unused tail of both buffers.
        res.data.i.truncate(nz);
        res.data.i.shrink_to_fit();
        res.data.vals.truncate(nz);
        res.data.vals.shrink_to_fit();
        res
    }
}
impl<T, R, C, S> Mul<T> for CsMatrix<T, R, C, S>
where
    T: Scalar + ClosedAdd + ClosedMul + Zero,
    R: Dim,
    C: Dim,
    S: CsStorageMut<T, R, C>,
{
    type Output = Self;

    /// Multiplies every value yielded by `values_mut()` by the scalar `rhs`, in place, and
    /// returns the matrix.
    fn mul(mut self, rhs: T) -> Self::Output {
        for e in self.values_mut() {
            *e *= rhs.inlined_clone()
        }

        self
    }
}