Minor refactoring for sp* ops
This commit is contained in:
parent
66cbd26702
commit
8983027b39
|
@ -31,6 +31,23 @@ where
|
||||||
assert_compatible_spmm_dims!(c, a, b);
|
assert_compatible_spmm_dims!(c, a, b);
|
||||||
|
|
||||||
match a {
|
match a {
|
||||||
|
Op::NoOp(ref a) => {
|
||||||
|
for j in 0..c.ncols() {
|
||||||
|
let mut c_col_j = c.column_mut(j);
|
||||||
|
for (c_ij, a_row_i) in c_col_j.iter_mut().zip(a.row_iter()) {
|
||||||
|
let mut dot_ij = T::zero();
|
||||||
|
for (&k, a_ik) in a_row_i.col_indices().iter().zip(a_row_i.values()) {
|
||||||
|
let b_contrib =
|
||||||
|
match b {
|
||||||
|
Op::NoOp(ref b) => b.index((k, j)),
|
||||||
|
Op::Transpose(ref b) => b.index((j, k))
|
||||||
|
};
|
||||||
|
dot_ij += a_ik.inlined_clone() * b_contrib.inlined_clone();
|
||||||
|
}
|
||||||
|
*c_ij = beta.inlined_clone() * c_ij.inlined_clone() + alpha.inlined_clone() * dot_ij;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
Op::Transpose(ref a) => {
|
Op::Transpose(ref a) => {
|
||||||
// In this case, we have to pre-multiply C by beta
|
// In this case, we have to pre-multiply C by beta
|
||||||
c *= beta;
|
c *= beta;
|
||||||
|
@ -57,23 +74,6 @@ where
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
Op::NoOp(ref a) => {
|
|
||||||
for j in 0..c.ncols() {
|
|
||||||
let mut c_col_j = c.column_mut(j);
|
|
||||||
for (c_ij, a_row_i) in c_col_j.iter_mut().zip(a.row_iter()) {
|
|
||||||
let mut dot_ij = T::zero();
|
|
||||||
for (&k, a_ik) in a_row_i.col_indices().iter().zip(a_row_i.values()) {
|
|
||||||
let b_contrib =
|
|
||||||
match b {
|
|
||||||
Op::NoOp(ref b) => b.index((k, j)),
|
|
||||||
Op::Transpose(ref b) => b.index((j, k))
|
|
||||||
};
|
|
||||||
dot_ij += a_ik.inlined_clone() * b_contrib.inlined_clone();
|
|
||||||
}
|
|
||||||
*c_ij = beta.inlined_clone() * c_ij.inlined_clone() + alpha.inlined_clone() * dot_ij;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -107,46 +107,48 @@ where
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
} else {
|
} else {
|
||||||
if let Op::Transpose(a) = a
|
match a {
|
||||||
{
|
Op::NoOp(a) => {
|
||||||
if beta != T::one() {
|
for (mut c_row_i, a_row_i) in c.row_iter_mut().zip(a.row_iter()) {
|
||||||
for c_ij in c.values_mut() {
|
if beta != T::one() {
|
||||||
*c_ij *= beta.inlined_clone();
|
for c_ij in c_row_i.values_mut() {
|
||||||
}
|
*c_ij *= beta.inlined_clone();
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
for (i, a_row_i) in a.row_iter().enumerate() {
|
let (mut c_cols, mut c_vals) = c_row_i.cols_and_values_mut();
|
||||||
for (&j, a_val) in a_row_i.col_indices().iter().zip(a_row_i.values()) {
|
let (a_cols, a_vals) = (a_row_i.col_indices(), a_row_i.values());
|
||||||
let a_val = a_val.inlined_clone();
|
|
||||||
let alpha = alpha.inlined_clone();
|
for (a_col, a_val) in a_cols.iter().zip(a_vals) {
|
||||||
match c.index_entry_mut(j, i) {
|
// TODO: Use exponential search instead of linear search.
|
||||||
SparseEntryMut::NonZero(c_ji) => { *c_ji += alpha * a_val }
|
// If C has substantially more entries in the row than A, then a linear search
|
||||||
SparseEntryMut::Zero => return Err(spadd_csr_unexpected_entry()),
|
// will needlessly visit many entries in C.
|
||||||
|
let (c_idx, _) = c_cols.iter()
|
||||||
|
.enumerate()
|
||||||
|
.find(|(_, c_col)| *c_col == a_col)
|
||||||
|
.ok_or_else(spadd_csr_unexpected_entry)?;
|
||||||
|
c_vals[c_idx] += alpha.inlined_clone() * a_val.inlined_clone();
|
||||||
|
c_cols = &c_cols[c_idx ..];
|
||||||
|
c_vals = &mut c_vals[c_idx ..];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if let Op::NoOp(a) = a {
|
Op::Transpose(a) => {
|
||||||
for (mut c_row_i, a_row_i) in c.row_iter_mut().zip(a.row_iter()) {
|
|
||||||
if beta != T::one() {
|
if beta != T::one() {
|
||||||
for c_ij in c_row_i.values_mut() {
|
for c_ij in c.values_mut() {
|
||||||
*c_ij *= beta.inlined_clone();
|
*c_ij *= beta.inlined_clone();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let (mut c_cols, mut c_vals) = c_row_i.cols_and_values_mut();
|
for (i, a_row_i) in a.row_iter().enumerate() {
|
||||||
let (a_cols, a_vals) = (a_row_i.col_indices(), a_row_i.values());
|
for (&j, a_val) in a_row_i.col_indices().iter().zip(a_row_i.values()) {
|
||||||
|
let a_val = a_val.inlined_clone();
|
||||||
for (a_col, a_val) in a_cols.iter().zip(a_vals) {
|
let alpha = alpha.inlined_clone();
|
||||||
// TODO: Use exponential search instead of linear search.
|
match c.index_entry_mut(j, i) {
|
||||||
// If C has substantially more entries in the row than A, then a linear search
|
SparseEntryMut::NonZero(c_ji) => { *c_ji += alpha * a_val }
|
||||||
// will needlessly visit many entries in C.
|
SparseEntryMut::Zero => return Err(spadd_csr_unexpected_entry()),
|
||||||
let (c_idx, _) = c_cols.iter()
|
}
|
||||||
.enumerate()
|
}
|
||||||
.find(|(_, c_col)| *c_col == a_col)
|
|
||||||
.ok_or_else(spadd_csr_unexpected_entry)?;
|
|
||||||
c_vals[c_idx] += alpha.inlined_clone() * a_val.inlined_clone();
|
|
||||||
c_cols = &c_cols[c_idx ..];
|
|
||||||
c_vals = &mut c_vals[c_idx ..];
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue