2020-12-02 23:56:22 +08:00
|
|
|
use crate::csr::CsrMatrix;
|
2020-12-03 00:04:19 +08:00
|
|
|
use crate::ops::{Transpose};
|
|
|
|
use nalgebra::{Scalar, DMatrixSlice, ClosedAdd, ClosedMul, DMatrixSliceMut};
|
2020-12-02 23:56:22 +08:00
|
|
|
use num_traits::{Zero, One};
|
|
|
|
|
2020-12-04 21:13:07 +08:00
|
|
|
/// Sparse-dense matrix-matrix multiplication `C <- beta * C + alpha * trans(A) * trans(B)`.
|
2020-12-02 23:56:22 +08:00
|
|
|
pub fn spmm_csr_dense<'a, T>(c: impl Into<DMatrixSliceMut<'a, T>>,
|
|
|
|
beta: T,
|
|
|
|
alpha: T,
|
2020-12-03 00:04:19 +08:00
|
|
|
trans_a: Transpose,
|
2020-12-02 23:56:22 +08:00
|
|
|
a: &CsrMatrix<T>,
|
2020-12-03 00:04:19 +08:00
|
|
|
trans_b: Transpose,
|
2020-12-02 23:56:22 +08:00
|
|
|
b: impl Into<DMatrixSlice<'a, T>>)
|
|
|
|
where
|
|
|
|
T: Scalar + ClosedAdd + ClosedMul + Zero + One
|
|
|
|
{
|
|
|
|
spmm_csr_dense_(c.into(), beta, alpha, trans_a, a, trans_b, b.into())
|
|
|
|
}
|
|
|
|
|
|
|
|
fn spmm_csr_dense_<T>(mut c: DMatrixSliceMut<T>,
|
|
|
|
beta: T,
|
|
|
|
alpha: T,
|
2020-12-03 00:04:19 +08:00
|
|
|
trans_a: Transpose,
|
2020-12-02 23:56:22 +08:00
|
|
|
a: &CsrMatrix<T>,
|
2020-12-03 00:04:19 +08:00
|
|
|
trans_b: Transpose,
|
2020-12-02 23:56:22 +08:00
|
|
|
b: DMatrixSlice<T>)
|
|
|
|
where
|
|
|
|
T: Scalar + ClosedAdd + ClosedMul + Zero + One
|
|
|
|
{
|
|
|
|
assert_compatible_spmm_dims!(c, a, b, trans_a, trans_b);
|
|
|
|
|
2020-12-03 00:04:19 +08:00
|
|
|
if trans_a.to_bool() {
|
2020-12-02 23:56:22 +08:00
|
|
|
// In this case, we have to pre-multiply C by beta
|
|
|
|
c *= beta;
|
|
|
|
|
|
|
|
for k in 0..a.nrows() {
|
|
|
|
let a_row_k = a.row(k);
|
|
|
|
for (&i, a_ki) in a_row_k.col_indices().iter().zip(a_row_k.values()) {
|
|
|
|
let gamma_ki = alpha.inlined_clone() * a_ki.inlined_clone();
|
|
|
|
let mut c_row_i = c.row_mut(i);
|
2020-12-03 00:04:19 +08:00
|
|
|
if trans_b.to_bool() {
|
2020-12-02 23:56:22 +08:00
|
|
|
let b_col_k = b.column(k);
|
|
|
|
for (c_ij, b_jk) in c_row_i.iter_mut().zip(b_col_k.iter()) {
|
|
|
|
*c_ij += gamma_ki.inlined_clone() * b_jk.inlined_clone();
|
|
|
|
}
|
|
|
|
} else {
|
|
|
|
let b_row_k = b.row(k);
|
|
|
|
for (c_ij, b_kj) in c_row_i.iter_mut().zip(b_row_k.iter()) {
|
|
|
|
*c_ij += gamma_ki.inlined_clone() * b_kj.inlined_clone();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
} else {
|
|
|
|
for j in 0..c.ncols() {
|
|
|
|
let mut c_col_j = c.column_mut(j);
|
|
|
|
for (c_ij, a_row_i) in c_col_j.iter_mut().zip(a.row_iter()) {
|
|
|
|
let mut dot_ij = T::zero();
|
|
|
|
for (&k, a_ik) in a_row_i.col_indices().iter().zip(a_row_i.values()) {
|
|
|
|
let b_contrib =
|
2020-12-03 00:04:19 +08:00
|
|
|
if trans_b.to_bool() { b.index((j, k)) } else { b.index((k, j)) };
|
2020-12-02 23:56:22 +08:00
|
|
|
dot_ij += a_ik.inlined_clone() * b_contrib.inlined_clone();
|
|
|
|
}
|
|
|
|
*c_ij = beta.inlined_clone() * c_ij.inlined_clone() + alpha.inlined_clone() * dot_ij;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|