prealloc everything, remove hashset, make it 4x faster

This commit is contained in:
Saurabh 2022-02-18 11:22:43 -07:00
parent e7d8a00836
commit ff3d1e4e35
2 changed files with 23 additions and 32 deletions

View File

@ -48,3 +48,9 @@ path = "example/spmm.rs"
[package.metadata.docs.rs] [package.metadata.docs.rs]
# Enable certain features when building docs for docs.rs # Enable certain features when building docs for docs.rs
features = [ "proptest-support", "compare" ] features = [ "proptest-support", "compare" ]
[profile.release]
opt-level = 3
lto = "fat"
codegen-units = 1
panic = "abort"

View File

@ -1,10 +1,9 @@
use std::collections::HashSet; //use std::collections::HashSet;
use crate::cs::CsMatrix; use crate::cs::CsMatrix;
use crate::ops::serial::{OperationError, OperationErrorKind}; use crate::ops::serial::{OperationError, OperationErrorKind};
use crate::ops::Op; use crate::ops::Op;
use crate::SparseEntryMut; use crate::SparseEntryMut;
use itertools::Itertools;
use nalgebra::{ClosedAdd, ClosedMul, DMatrixSlice, DMatrixSliceMut, Scalar}; use nalgebra::{ClosedAdd, ClosedMul, DMatrixSlice, DMatrixSliceMut, Scalar};
use num_traits::{One, Zero}; use num_traits::{One, Zero};
@ -33,56 +32,42 @@ pub fn spmm_cs_prealloc<T>(
where where
T: Scalar + ClosedAdd + ClosedMul + Zero + One, T: Scalar + ClosedAdd + ClosedMul + Zero + One,
{ {
let some_val = Zero::zero();
let mut scratchpad_values: Vec<T> = vec![some_val; b.pattern().minor_dim()];
let mut scratchpad_indices: Vec<usize> = vec![0; b.pattern().minor_dim()];
let mut scratchpad_used: Vec<bool> = vec![false; b.pattern().minor_dim()];
let mut right_end = 0usize;
for i in 0..c.pattern().major_dim() { for i in 0..c.pattern().major_dim() {
let a_lane_i = a.get_lane(i).unwrap(); let a_lane_i = a.get_lane(i).unwrap();
let some_val = Zero::zero();
let mut scratchpad_values: Vec<T> = vec![some_val; b.pattern().minor_dim()];
let mut scratchpad_indices: HashSet<usize> = HashSet::new();
let mut c_lane_i = c.get_lane_mut(i).unwrap(); let mut c_lane_i = c.get_lane_mut(i).unwrap();
//let (indices, values) = c_lane_i.indices_and_values_mut();
//indices
// .iter()
// .zip(values.iter())
// .for_each(|(id, val)| scratchpad_values[*id] = beta.clone() * val.clone());
//for (index, c_ij) in c_lane_i.indices_and_values_mut() {
// *c_ij = beta.clone() * c_ij.clone();
//}
for (&k, a_ik) in a_lane_i.minor_indices().iter().zip(a_lane_i.values()) { for (&k, a_ik) in a_lane_i.minor_indices().iter().zip(a_lane_i.values()) {
let b_lane_k = b.get_lane(k).unwrap(); let b_lane_k = b.get_lane(k).unwrap();
//let (mut c_lane_i_cols, mut c_lane_i_values) = c_lane_i.indices_and_values_mut();
let alpha_aik = alpha.clone() * a_ik.clone(); let alpha_aik = alpha.clone() * a_ik.clone();
for (j, b_kj) in b_lane_k.minor_indices().iter().zip(b_lane_k.values()) { for (j, b_kj) in b_lane_k.minor_indices().iter().zip(b_lane_k.values()) {
// Determine the location in C to append the value // Determine the location in C to append the value
// TODO make a scratchpad and defer the accumulation into C after processing one
// full row of A.
scratchpad_values[*j] += alpha_aik.clone() * b_kj.clone(); scratchpad_values[*j] += alpha_aik.clone() * b_kj.clone();
scratchpad_indices.insert(*j); if !scratchpad_used[*j] {
//let (c_local_idx, _) = c_lane_i_cols scratchpad_indices[right_end] = *j;
// .iter() right_end += 1;
// .enumerate() scratchpad_used[*j] = true;
// .find(|(_, c_col)| *c_col == j) }
// .ok_or_else(spmm_cs_unexpected_entry)?;
//c_lane_i_values[c_local_idx] += alpha_aik.clone() * b_kj.clone();
//c_lane_i_cols = &c_lane_i_cols[c_local_idx..];
//c_lane_i_values = &mut c_lane_i_values[c_local_idx..];
} }
} }
// sort the indices, and then access the relevant indices (in sorted order) from values // sort the indices, and then access the relevant indices (in sorted order) from values
// into C. // into C.
let sorted_indices: Vec<usize> = scratchpad_indices[0..right_end].sort_unstable();
Itertools::sorted(scratchpad_indices.into_iter()).collect();
c_lane_i c_lane_i
.values_mut() .values_mut()
.iter_mut() .iter_mut()
.zip(sorted_indices.into_iter()) .zip(scratchpad_indices[0..right_end].iter())
.for_each(|(output_ref, index)| { .for_each(|(output_ref, index)| {
*output_ref = beta.clone() * output_ref.clone() + scratchpad_values[index].clone() *output_ref = beta.clone() * output_ref.clone() + scratchpad_values[*index].clone();
scratchpad_used[*index] = false;
scratchpad_values[*index] = Zero::zero();
}); });
right_end = 0usize;
} }
Ok(()) Ok(())