diff --git a/nalgebra-sparse/src/ops/serial/csr.rs b/nalgebra-sparse/src/ops/serial/csr.rs index 4dd3ae61..da975c7a 100644 --- a/nalgebra-sparse/src/ops/serial/csr.rs +++ b/nalgebra-sparse/src/ops/serial/csr.rs @@ -31,6 +31,23 @@ where assert_compatible_spmm_dims!(c, a, b); match a { + Op::NoOp(ref a) => { + for j in 0..c.ncols() { + let mut c_col_j = c.column_mut(j); + for (c_ij, a_row_i) in c_col_j.iter_mut().zip(a.row_iter()) { + let mut dot_ij = T::zero(); + for (&k, a_ik) in a_row_i.col_indices().iter().zip(a_row_i.values()) { + let b_contrib = + match b { + Op::NoOp(ref b) => b.index((k, j)), + Op::Transpose(ref b) => b.index((j, k)) + }; + dot_ij += a_ik.inlined_clone() * b_contrib.inlined_clone(); + } + *c_ij = beta.inlined_clone() * c_ij.inlined_clone() + alpha.inlined_clone() * dot_ij; + } + } + }, Op::Transpose(ref a) => { // In this case, we have to pre-multiply C by beta c *= beta; @@ -57,23 +74,6 @@ where } } }, - Op::NoOp(ref a) => { - for j in 0..c.ncols() { - let mut c_col_j = c.column_mut(j); - for (c_ij, a_row_i) in c_col_j.iter_mut().zip(a.row_iter()) { - let mut dot_ij = T::zero(); - for (&k, a_ik) in a_row_i.col_indices().iter().zip(a_row_i.values()) { - let b_contrib = - match b { - Op::NoOp(ref b) => b.index((k, j)), - Op::Transpose(ref b) => b.index((j, k)) - }; - dot_ij += a_ik.inlined_clone() * b_contrib.inlined_clone(); - } - *c_ij = beta.inlined_clone() * c_ij.inlined_clone() + alpha.inlined_clone() * dot_ij; - } - } - } } } @@ -107,46 +107,48 @@ where } Ok(()) } else { - if let Op::Transpose(a) = a - { - if beta != T::one() { - for c_ij in c.values_mut() { - *c_ij *= beta.inlined_clone(); - } - } + match a { + Op::NoOp(a) => { + for (mut c_row_i, a_row_i) in c.row_iter_mut().zip(a.row_iter()) { + if beta != T::one() { + for c_ij in c_row_i.values_mut() { + *c_ij *= beta.inlined_clone(); + } + } - for (i, a_row_i) in a.row_iter().enumerate() { - for (&j, a_val) in a_row_i.col_indices().iter().zip(a_row_i.values()) { - let a_val = a_val.inlined_clone(); - let alpha = alpha.inlined_clone(); - match c.index_entry_mut(j, i) { - SparseEntryMut::NonZero(c_ji) => { *c_ji += alpha * a_val } - SparseEntryMut::Zero => return Err(spadd_csr_unexpected_entry()), + let (mut c_cols, mut c_vals) = c_row_i.cols_and_values_mut(); + let (a_cols, a_vals) = (a_row_i.col_indices(), a_row_i.values()); + + for (a_col, a_val) in a_cols.iter().zip(a_vals) { + // TODO: Use exponential search instead of linear search. + // If C has substantially more entries in the row than A, then a line search + // will needlessly visit many entries in C. + let (c_idx, _) = c_cols.iter() + .enumerate() + .find(|(_, c_col)| *c_col == a_col) + .ok_or_else(spadd_csr_unexpected_entry)?; + c_vals[c_idx] += alpha.inlined_clone() * a_val.inlined_clone(); + c_cols = &c_cols[c_idx ..]; + c_vals = &mut c_vals[c_idx ..]; } } } - } else if let Op::NoOp(a) = a { - for (mut c_row_i, a_row_i) in c.row_iter_mut().zip(a.row_iter()) { + Op::Transpose(a) => { if beta != T::one() { - for c_ij in c_row_i.values_mut() { + for c_ij in c.values_mut() { *c_ij *= beta.inlined_clone(); } } - let (mut c_cols, mut c_vals) = c_row_i.cols_and_values_mut(); - let (a_cols, a_vals) = (a_row_i.col_indices(), a_row_i.values()); - - for (a_col, a_val) in a_cols.iter().zip(a_vals) { - // TODO: Use exponential search instead of linear search. - // If C has substantially more entries in the row than A, then a line search - // will needlessly visit many entries in C. - let (c_idx, _) = c_cols.iter() - .enumerate() - .find(|(_, c_col)| *c_col == a_col) - .ok_or_else(spadd_csr_unexpected_entry)?; - c_vals[c_idx] += alpha.inlined_clone() * a_val.inlined_clone(); - c_cols = &c_cols[c_idx ..]; - c_vals = &mut c_vals[c_idx ..]; + for (i, a_row_i) in a.row_iter().enumerate() { + for (&j, a_val) in a_row_i.col_indices().iter().zip(a_row_i.values()) { + let a_val = a_val.inlined_clone(); + let alpha = alpha.inlined_clone(); + match c.index_entry_mut(j, i) { + SparseEntryMut::NonZero(c_ji) => { *c_ji += alpha * a_val } + SparseEntryMut::Zero => return Err(spadd_csr_unexpected_entry()), + } + } } } }