use crate::csr::CsrMatrix; use std::ops::Add; use crate::ops::serial::{spadd_csr, spadd_build_pattern}; use nalgebra::{ClosedAdd, ClosedMul, Scalar}; use num_traits::{Zero, One}; use std::sync::Arc; use crate::ops::Transpose; use crate::pattern::SparsityPattern; impl<'a, T> Add<&'a CsrMatrix> for &'a CsrMatrix where // TODO: Consider introducing wrapper trait for these things? It's technically a "Ring", // I guess... T: Scalar + ClosedAdd + ClosedMul + Zero + One { type Output = CsrMatrix; fn add(self, rhs: &'a CsrMatrix) -> Self::Output { let mut pattern = SparsityPattern::new(self.nrows(), self.ncols()); spadd_build_pattern(&mut pattern, self.pattern(), rhs.pattern()); let values = vec![T::zero(); pattern.nnz()]; // We are giving data that is valid by definition, so it is safe to unwrap below let mut result = CsrMatrix::try_from_pattern_and_values(Arc::new(pattern), values) .unwrap(); spadd_csr(&mut result, T::zero(), T::one(), Transpose(false), &self).unwrap(); spadd_csr(&mut result, T::one(), T::one(), Transpose(false), &rhs).unwrap(); result } } impl<'a, T> Add<&'a CsrMatrix> for CsrMatrix where T: Scalar + ClosedAdd + ClosedMul + Zero + One { type Output = CsrMatrix; fn add(mut self, rhs: &'a CsrMatrix) -> Self::Output { if Arc::ptr_eq(self.pattern(), rhs.pattern()) { spadd_csr(&mut self, T::one(), T::one(), Transpose(false), &rhs).unwrap(); self } else { &self + rhs } } } impl<'a, T> Add> for &'a CsrMatrix where T: Scalar + ClosedAdd + ClosedMul + Zero + One { type Output = CsrMatrix; fn add(self, rhs: CsrMatrix) -> Self::Output { rhs + self } } impl Add> for CsrMatrix where T: Scalar + ClosedAdd + ClosedMul + Zero + One { type Output = Self; fn add(self, rhs: CsrMatrix) -> Self::Output { self + &rhs } }