use num::{One, Zero}; use simba::scalar::{ClosedAdd, ClosedMul}; use std::ops::{Add, Mul}; use crate::allocator::Allocator; use crate::constraint::{AreMultipliable, DimEq, ShapeConstraint}; use crate::sparse::{CsMatrix, CsStorage, CsStorageMut, CsVector}; use crate::storage::StorageMut; use crate::{Const, DefaultAllocator, Dim, OVector, Scalar, Vector}; impl> CsMatrix { fn scatter( &self, j: usize, beta: T, timestamps: &mut [usize], timestamp: usize, workspace: &mut [T], mut nz: usize, res: &mut CsMatrix, ) -> usize where T: ClosedAdd + ClosedMul, DefaultAllocator: Allocator, { for (i, val) in self.data.column_entries(j) { if timestamps[i] < timestamp { timestamps[i] = timestamp; res.data.i[nz] = i; nz += 1; workspace[i] = val * beta.inlined_clone(); } else { workspace[i] += val * beta.inlined_clone(); } } nz } } /* impl CsVector { pub fn axpy(&mut self, alpha: T, x: CsVector, beta: T) { // First, compute the number of non-zero entries. let mut nnzero = 0; // Allocate a size large enough. self.data.set_column_len(0, nnzero); // Fill with the axpy. let mut i = self.len(); let mut j = x.len(); let mut k = nnzero - 1; let mut rid1 = self.data.row_index(0, i - 1); let mut rid2 = x.data.row_index(0, j - 1); while k > 0 { if rid1 == rid2 { self.data.set_row_index(0, k, rid1); self[k] = alpha * x[j] + beta * self[k]; i -= 1; j -= 1; } else if rid1 < rid2 { self.data.set_row_index(0, k, rid1); self[k] = beta * self[i]; i -= 1; } else { self.data.set_row_index(0, k, rid2); self[k] = alpha * x[j]; j -= 1; } k -= 1; } } } */ impl> Vector { /// Perform a sparse axpy operation: `self = alpha * x + beta * self` operation. pub fn axpy_cs(&mut self, alpha: T, x: &CsVector, beta: T) where S2: CsStorage, ShapeConstraint: DimEq, { if beta.is_zero() { for i in 0..x.len() { unsafe { let k = x.data.row_index_unchecked(i); let y = self.vget_unchecked_mut(k); *y = alpha.inlined_clone() * x.data.get_value_unchecked(i).inlined_clone(); } } } else { // Needed to be sure even components not present on `x` are multiplied. *self *= beta.inlined_clone(); for i in 0..x.len() { unsafe { let k = x.data.row_index_unchecked(i); let y = self.vget_unchecked_mut(k); *y += alpha.inlined_clone() * x.data.get_value_unchecked(i).inlined_clone(); } } } } /* pub fn gemv_sparse(&mut self, alpha: T, a: &CsMatrix, x: &DVector, beta: T) where S2: CsStorage { let col2 = a.column(0); let val = unsafe { *x.vget_unchecked(0) }; self.axpy_sparse(alpha * val, &col2, beta); for j in 1..ncols2 { let col2 = a.column(j); let val = unsafe { *x.vget_unchecked(j) }; self.axpy_sparse(alpha * val, &col2, T::one()); } } */ } impl<'a, 'b, T, R1, R2, C1, C2, S1, S2> Mul<&'b CsMatrix> for &'a CsMatrix where T: Scalar + ClosedAdd + ClosedMul + Zero, R1: Dim, C1: Dim, R2: Dim, C2: Dim, S1: CsStorage, S2: CsStorage, ShapeConstraint: AreMultipliable, DefaultAllocator: Allocator + Allocator + Allocator, { type Output = CsMatrix; fn mul(self, rhs: &'b CsMatrix) -> Self::Output { let (nrows1, ncols1) = self.data.shape(); let (nrows2, ncols2) = rhs.data.shape(); assert_eq!( ncols1.value(), nrows2.value(), "Mismatched dimensions for matrix multiplication." ); let mut res = CsMatrix::new_uninitialized_generic(nrows1, ncols2, self.len() + rhs.len()); let mut workspace = OVector::::zeros_generic(nrows1, Const::<1>); let mut nz = 0; for j in 0..ncols2.value() { res.data.p[j] = nz; let new_size_bound = nz + nrows1.value(); res.data.i.resize(new_size_bound, 0); res.data.vals.resize(new_size_bound, T::zero()); for (i, beta) in rhs.data.column_entries(j) { for (k, val) in self.data.column_entries(i) { workspace[k] += val.inlined_clone() * beta.inlined_clone(); } } for (i, val) in workspace.as_mut_slice().iter_mut().enumerate() { if !val.is_zero() { res.data.i[nz] = i; res.data.vals[nz] = val.inlined_clone(); *val = T::zero(); nz += 1; } } } // NOTE: the following has a lower complexity, but is slower in many cases, likely because // of branching inside of the inner loop. // // let mut res = CsMatrix::new_uninitialized_generic(nrows1, ncols2, self.len() + rhs.len()); // let mut timestamps = OVector::zeros_generic(nrows1, Const::<)>; // let mut workspace = unsafe { OVector::new_uninitialized_generic(nrows1, Const::<)> }; // let mut nz = 0; // // for j in 0..ncols2.value() { // res.data.p[j] = nz; // let new_size_bound = nz + nrows1.value(); // res.data.i.resize(new_size_bound, 0); // res.data.vals.resize(new_size_bound, T::zero()); // // for (i, val) in rhs.data.column_entries(j) { // nz = self.scatter( // i, // val, // timestamps.as_mut_slice(), // j + 1, // workspace.as_mut_slice(), // nz, // &mut res, // ); // } // // // Keep the output sorted. // let range = res.data.p[j]..nz; // res.data.i[range.clone()].sort(); // // for p in range { // res.data.vals[p] = workspace[res.data.i[p]] // } // } res.data.i.truncate(nz); res.data.i.shrink_to_fit(); res.data.vals.truncate(nz); res.data.vals.shrink_to_fit(); res } } impl<'a, 'b, T, R1, R2, C1, C2, S1, S2> Add<&'b CsMatrix> for &'a CsMatrix where T: Scalar + ClosedAdd + ClosedMul + One, R1: Dim, C1: Dim, R2: Dim, C2: Dim, S1: CsStorage, S2: CsStorage, ShapeConstraint: DimEq + DimEq, DefaultAllocator: Allocator + Allocator + Allocator, { type Output = CsMatrix; fn add(self, rhs: &'b CsMatrix) -> Self::Output { let (nrows1, ncols1) = self.data.shape(); let (nrows2, ncols2) = rhs.data.shape(); assert_eq!( (nrows1.value(), ncols1.value()), (nrows2.value(), ncols2.value()), "Mismatched dimensions for matrix sum." ); let mut res = CsMatrix::new_uninitialized_generic(nrows1, ncols2, self.len() + rhs.len()); let mut timestamps = OVector::zeros_generic(nrows1, Const::<1>); let mut workspace = unsafe { crate::unimplemented_or_uninitialized_generic!(nrows1, Const::<1>) }; let mut nz = 0; for j in 0..ncols2.value() { res.data.p[j] = nz; nz = self.scatter( j, T::one(), timestamps.as_mut_slice(), j + 1, workspace.as_mut_slice(), nz, &mut res, ); nz = rhs.scatter( j, T::one(), timestamps.as_mut_slice(), j + 1, workspace.as_mut_slice(), nz, &mut res, ); // Keep the output sorted. let range = res.data.p[j]..nz; res.data.i[range.clone()].sort(); for p in range { res.data.vals[p] = workspace[res.data.i[p]].inlined_clone() } } res.data.i.truncate(nz); res.data.i.shrink_to_fit(); res.data.vals.truncate(nz); res.data.vals.shrink_to_fit(); res } } impl<'a, 'b, T, R, C, S> Mul for CsMatrix where T: Scalar + ClosedAdd + ClosedMul + Zero, R: Dim, C: Dim, S: CsStorageMut, { type Output = Self; fn mul(mut self, rhs: T) -> Self::Output { for e in self.values_mut() { *e *= rhs.inlined_clone() } self } }