use crate::csr::CsrMatrix; use crate::ops::{Op}; use crate::ops::serial::{OperationError}; use nalgebra::{Scalar, DMatrixSlice, ClosedAdd, ClosedMul, DMatrixSliceMut}; use num_traits::{Zero, One}; use std::borrow::Cow; use crate::ops::serial::cs::{spmm_cs_prealloc, spmm_cs_dense, spadd_cs_prealloc}; /// Sparse-dense matrix-matrix multiplication `C <- beta * C + alpha * op(A) * op(B)`. pub fn spmm_csr_dense<'a, T>(beta: T, c: impl Into>, alpha: T, a: Op<&CsrMatrix>, b: Op>>) where T: Scalar + ClosedAdd + ClosedMul + Zero + One { let b = b.convert(); spmm_csr_dense_(beta, c.into(), alpha, a, b) } fn spmm_csr_dense_(beta: T, c: DMatrixSliceMut, alpha: T, a: Op<&CsrMatrix>, b: Op>) where T: Scalar + ClosedAdd + ClosedMul + Zero + One { assert_compatible_spmm_dims!(c, a, b); spmm_cs_dense(beta, c, alpha, a.map_same_op(|a| &a.cs), b) } /// Sparse matrix addition `C <- beta * C + alpha * op(A)`. /// /// If the pattern of `c` does not accommodate all the non-zero entries in `a`, an error is /// returned. pub fn spadd_csr_prealloc(beta: T, c: &mut CsrMatrix, alpha: T, a: Op<&CsrMatrix>) -> Result<(), OperationError> where T: Scalar + ClosedAdd + ClosedMul + Zero + One { assert_compatible_spadd_dims!(c, a); spadd_cs_prealloc(beta, &mut c.cs, alpha, a.map_same_op(|a| &a.cs)) } /// Sparse-sparse matrix multiplication, `C <- beta * C + alpha * op(A) * op(B)`. pub fn spmm_csr_prealloc( beta: T, c: &mut CsrMatrix, alpha: T, a: Op<&CsrMatrix>, b: Op<&CsrMatrix>) -> Result<(), OperationError> where T: Scalar + ClosedAdd + ClosedMul + Zero + One { assert_compatible_spmm_dims!(c, a, b); use Op::{NoOp, Transpose}; match (&a, &b) { (NoOp(ref a), NoOp(ref b)) => { spmm_cs_prealloc(beta, &mut c.cs, alpha, &a.cs, &b.cs) }, _ => { // Currently we handle transposition by explicitly precomputing transposed matrices // and calling the operation again without transposition // TODO: At least use workspaces to allow control of allocations. Maybe // consider implementing certain patterns (like A^T * B) explicitly let a_ref: &CsrMatrix = a.inner_ref(); let b_ref: &CsrMatrix = b.inner_ref(); let (a, b) = { use Cow::*; match (&a, &b) { (NoOp(_), NoOp(_)) => unreachable!(), (Transpose(ref a), NoOp(_)) => (Owned(a.transpose()), Borrowed(b_ref)), (NoOp(_), Transpose(ref b)) => (Borrowed(a_ref), Owned(b.transpose())), (Transpose(ref a), Transpose(ref b)) => (Owned(a.transpose()), Owned(b.transpose())) } }; spmm_csr_prealloc(beta, c, alpha, NoOp(a.as_ref()), NoOp(b.as_ref())) } } }