Minor refactoring for sp* ops

This commit is contained in:
Andreas Longva 2020-12-21 17:09:26 +01:00
parent 66cbd26702
commit 8983027b39

View File

@ -31,6 +31,23 @@ where
assert_compatible_spmm_dims!(c, a, b);
match a {
Op::NoOp(ref a) => {
for j in 0..c.ncols() {
let mut c_col_j = c.column_mut(j);
for (c_ij, a_row_i) in c_col_j.iter_mut().zip(a.row_iter()) {
let mut dot_ij = T::zero();
for (&k, a_ik) in a_row_i.col_indices().iter().zip(a_row_i.values()) {
let b_contrib =
match b {
Op::NoOp(ref b) => b.index((k, j)),
Op::Transpose(ref b) => b.index((j, k))
};
dot_ij += a_ik.inlined_clone() * b_contrib.inlined_clone();
}
*c_ij = beta.inlined_clone() * c_ij.inlined_clone() + alpha.inlined_clone() * dot_ij;
}
}
},
Op::Transpose(ref a) => {
// In this case, we have to pre-multiply C by beta
c *= beta;
@ -57,23 +74,6 @@ where
}
}
},
Op::NoOp(ref a) => {
for j in 0..c.ncols() {
let mut c_col_j = c.column_mut(j);
for (c_ij, a_row_i) in c_col_j.iter_mut().zip(a.row_iter()) {
let mut dot_ij = T::zero();
for (&k, a_ik) in a_row_i.col_indices().iter().zip(a_row_i.values()) {
let b_contrib =
match b {
Op::NoOp(ref b) => b.index((k, j)),
Op::Transpose(ref b) => b.index((j, k))
};
dot_ij += a_ik.inlined_clone() * b_contrib.inlined_clone();
}
*c_ij = beta.inlined_clone() * c_ij.inlined_clone() + alpha.inlined_clone() * dot_ij;
}
}
}
}
}
@ -107,25 +107,8 @@ where
}
Ok(())
} else {
if let Op::Transpose(a) = a
{
if beta != T::one() {
for c_ij in c.values_mut() {
*c_ij *= beta.inlined_clone();
}
}
for (i, a_row_i) in a.row_iter().enumerate() {
for (&j, a_val) in a_row_i.col_indices().iter().zip(a_row_i.values()) {
let a_val = a_val.inlined_clone();
let alpha = alpha.inlined_clone();
match c.index_entry_mut(j, i) {
SparseEntryMut::NonZero(c_ji) => { *c_ji += alpha * a_val }
SparseEntryMut::Zero => return Err(spadd_csr_unexpected_entry()),
}
}
}
} else if let Op::NoOp(a) = a {
match a {
Op::NoOp(a) => {
for (mut c_row_i, a_row_i) in c.row_iter_mut().zip(a.row_iter()) {
if beta != T::one() {
for c_ij in c_row_i.values_mut() {
@ -150,6 +133,25 @@ where
}
}
}
Op::Transpose(a) => {
if beta != T::one() {
for c_ij in c.values_mut() {
*c_ij *= beta.inlined_clone();
}
}
for (i, a_row_i) in a.row_iter().enumerate() {
for (&j, a_val) in a_row_i.col_indices().iter().zip(a_row_i.values()) {
let a_val = a_val.inlined_clone();
let alpha = alpha.inlined_clone();
match c.index_entry_mut(j, i) {
SparseEntryMut::NonZero(c_ji) => { *c_ji += alpha * a_val }
SparseEntryMut::Zero => return Err(spadd_csr_unexpected_entry()),
}
}
}
}
}
Ok(())
}
}