2020-12-02 23:56:22 +08:00
|
|
|
use crate::csr::CsrMatrix;
|
2020-12-21 22:09:29 +08:00
|
|
|
use crate::ops::{Op};
|
2020-12-30 23:09:46 +08:00
|
|
|
use crate::ops::serial::{OperationError};
|
2020-12-03 00:04:19 +08:00
|
|
|
use nalgebra::{Scalar, DMatrixSlice, ClosedAdd, ClosedMul, DMatrixSliceMut};
|
2020-12-02 23:56:22 +08:00
|
|
|
use num_traits::{Zero, One};
|
2020-12-14 23:55:06 +08:00
|
|
|
use std::borrow::Cow;
|
2020-12-30 23:09:46 +08:00
|
|
|
use crate::ops::serial::cs::{spmm_cs_prealloc, spmm_cs_dense, spadd_cs_prealloc};
|
2020-12-02 23:56:22 +08:00
|
|
|
|
2020-12-21 22:09:29 +08:00
|
|
|
/// Sparse-dense matrix-matrix multiplication `C <- beta * C + alpha * op(A) * op(B)`.
|
2020-12-21 22:42:32 +08:00
|
|
|
pub fn spmm_csr_dense<'a, T>(beta: T,
|
|
|
|
c: impl Into<DMatrixSliceMut<'a, T>>,
|
2020-12-02 23:56:22 +08:00
|
|
|
alpha: T,
|
2020-12-21 22:09:29 +08:00
|
|
|
a: Op<&CsrMatrix<T>>,
|
|
|
|
b: Op<impl Into<DMatrixSlice<'a, T>>>)
|
2020-12-02 23:56:22 +08:00
|
|
|
where
|
|
|
|
T: Scalar + ClosedAdd + ClosedMul + Zero + One
|
|
|
|
{
|
2020-12-21 22:09:29 +08:00
|
|
|
let b = b.convert();
|
2020-12-21 22:42:32 +08:00
|
|
|
spmm_csr_dense_(beta, c.into(), alpha, a, b)
|
2020-12-02 23:56:22 +08:00
|
|
|
}
|
|
|
|
|
2020-12-21 22:42:32 +08:00
|
|
|
fn spmm_csr_dense_<T>(beta: T,
|
2020-12-30 23:09:46 +08:00
|
|
|
c: DMatrixSliceMut<T>,
|
2020-12-02 23:56:22 +08:00
|
|
|
alpha: T,
|
2020-12-21 22:09:29 +08:00
|
|
|
a: Op<&CsrMatrix<T>>,
|
|
|
|
b: Op<DMatrixSlice<T>>)
|
2020-12-02 23:56:22 +08:00
|
|
|
where
|
|
|
|
T: Scalar + ClosedAdd + ClosedMul + Zero + One
|
|
|
|
{
|
2020-12-21 22:09:29 +08:00
|
|
|
assert_compatible_spmm_dims!(c, a, b);
|
2020-12-30 23:09:46 +08:00
|
|
|
spmm_cs_dense(beta, c, alpha, a.map_same_op(|a| &a.cs), b)
|
2020-12-10 20:30:37 +08:00
|
|
|
}
|
|
|
|
|
2020-12-21 22:09:29 +08:00
|
|
|
/// Sparse matrix addition `C <- beta * C + alpha * op(A)`.
|
2020-12-10 20:30:37 +08:00
|
|
|
///
|
|
|
|
/// If the pattern of `c` does not accommodate all the non-zero entries in `a`, an error is
|
|
|
|
/// returned.
|
2020-12-21 23:05:38 +08:00
|
|
|
pub fn spadd_csr_prealloc<T>(beta: T,
|
|
|
|
c: &mut CsrMatrix<T>,
|
|
|
|
alpha: T,
|
|
|
|
a: Op<&CsrMatrix<T>>)
|
|
|
|
-> Result<(), OperationError>
|
2020-12-10 20:30:37 +08:00
|
|
|
where
|
|
|
|
T: Scalar + ClosedAdd + ClosedMul + Zero + One
|
|
|
|
{
|
2020-12-21 22:09:29 +08:00
|
|
|
assert_compatible_spadd_dims!(c, a);
|
2020-12-30 23:09:46 +08:00
|
|
|
spadd_cs_prealloc(beta, &mut c.cs, alpha, a.map_same_op(|a| &a.cs))
|
2020-12-16 21:06:12 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
/// Sparse-sparse matrix multiplication, `C <- beta * C + alpha * op(A) * op(B)`.
|
2020-12-21 23:05:38 +08:00
|
|
|
pub fn spmm_csr_prealloc<T>(
|
2020-12-16 21:06:12 +08:00
|
|
|
beta: T,
|
2020-12-21 22:42:32 +08:00
|
|
|
c: &mut CsrMatrix<T>,
|
2020-12-16 21:06:12 +08:00
|
|
|
alpha: T,
|
2020-12-21 22:09:29 +08:00
|
|
|
a: Op<&CsrMatrix<T>>,
|
|
|
|
b: Op<&CsrMatrix<T>>)
|
2020-12-16 21:06:12 +08:00
|
|
|
-> Result<(), OperationError>
|
|
|
|
where
|
|
|
|
T: Scalar + ClosedAdd + ClosedMul + Zero + One
|
|
|
|
{
|
2020-12-21 22:09:29 +08:00
|
|
|
assert_compatible_spmm_dims!(c, a, b);
|
2020-12-16 21:06:12 +08:00
|
|
|
|
2020-12-21 22:09:29 +08:00
|
|
|
use Op::{NoOp, Transpose};
|
2020-12-16 21:06:12 +08:00
|
|
|
|
2020-12-21 22:09:29 +08:00
|
|
|
match (&a, &b) {
|
|
|
|
(NoOp(ref a), NoOp(ref b)) => {
|
2020-12-30 23:09:46 +08:00
|
|
|
spmm_cs_prealloc(beta, &mut c.cs, alpha, &a.cs, &b.cs)
|
2020-12-21 22:09:29 +08:00
|
|
|
},
|
|
|
|
_ => {
|
|
|
|
// Currently we handle transposition by explicitly precomputing transposed matrices
|
|
|
|
// and calling the operation again without transposition
|
|
|
|
// TODO: At least use workspaces to allow control of allocations. Maybe
|
|
|
|
// consider implementing certain patterns (like A^T * B) explicitly
|
|
|
|
let a_ref: &CsrMatrix<T> = a.inner_ref();
|
|
|
|
let b_ref: &CsrMatrix<T> = b.inner_ref();
|
|
|
|
let (a, b) = {
|
|
|
|
use Cow::*;
|
|
|
|
match (&a, &b) {
|
|
|
|
(NoOp(_), NoOp(_)) => unreachable!(),
|
|
|
|
(Transpose(ref a), NoOp(_)) => (Owned(a.transpose()), Borrowed(b_ref)),
|
|
|
|
(NoOp(_), Transpose(ref b)) => (Borrowed(a_ref), Owned(b.transpose())),
|
|
|
|
(Transpose(ref a), Transpose(ref b)) => (Owned(a.transpose()), Owned(b.transpose()))
|
|
|
|
}
|
|
|
|
};
|
2020-12-16 21:06:12 +08:00
|
|
|
|
2020-12-21 23:05:38 +08:00
|
|
|
spmm_csr_prealloc(beta, c, alpha, NoOp(a.as_ref()), NoOp(b.as_ref()))
|
2020-12-21 22:09:29 +08:00
|
|
|
}
|
2020-12-16 21:06:12 +08:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|