forked from M-Labs/nalgebra
add spmm example and change the kernel
This commit is contained in:
parent
39bb572557
commit
776fef26c3
@ -24,6 +24,7 @@ io = [ "pest", "pest_derive" ]
|
|||||||
slow-tests = []
|
slow-tests = []
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
|
criterion = { version = "0.3", features = ["html_reports"] }
|
||||||
nalgebra = { version="0.30", path = "../" }
|
nalgebra = { version="0.30", path = "../" }
|
||||||
num-traits = { version = "0.2", default-features = false }
|
num-traits = { version = "0.2", default-features = false }
|
||||||
proptest = { version = "1.0", optional = true }
|
proptest = { version = "1.0", optional = true }
|
||||||
@ -31,6 +32,7 @@ matrixcompare-core = { version = "0.1.0", optional = true }
|
|||||||
pest = { version = "2", optional = true }
|
pest = { version = "2", optional = true }
|
||||||
pest_derive = { version = "2", optional = true }
|
pest_derive = { version = "2", optional = true }
|
||||||
serde = { version = "1.0", default-features = false, features = [ "derive" ], optional = true }
|
serde = { version = "1.0", default-features = false, features = [ "derive" ], optional = true }
|
||||||
|
itertools = "0.10"
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
itertools = "0.10"
|
itertools = "0.10"
|
||||||
@ -38,6 +40,11 @@ matrixcompare = { version = "0.3.0", features = [ "proptest-support" ] }
|
|||||||
nalgebra = { version="0.30", path = "../", features = ["compare"] }
|
nalgebra = { version="0.30", path = "../", features = ["compare"] }
|
||||||
serde_json = "1.0"
|
serde_json = "1.0"
|
||||||
|
|
||||||
|
[[example]]
|
||||||
|
name = "spmm"
|
||||||
|
required-features = ["io"]
|
||||||
|
path = "example/spmm.rs"
|
||||||
|
|
||||||
[package.metadata.docs.rs]
|
[package.metadata.docs.rs]
|
||||||
# Enable certain features when building docs for docs.rs
|
# Enable certain features when building docs for docs.rs
|
||||||
features = [ "proptest-support", "compare" ]
|
features = [ "proptest-support", "compare" ]
|
||||||
|
@ -1,16 +1,19 @@
|
|||||||
|
use std::collections::HashSet;
|
||||||
|
|
||||||
use crate::cs::CsMatrix;
|
use crate::cs::CsMatrix;
|
||||||
use crate::ops::serial::{OperationError, OperationErrorKind};
|
use crate::ops::serial::{OperationError, OperationErrorKind};
|
||||||
use crate::ops::Op;
|
use crate::ops::Op;
|
||||||
use crate::SparseEntryMut;
|
use crate::SparseEntryMut;
|
||||||
|
use itertools::Itertools;
|
||||||
use nalgebra::{ClosedAdd, ClosedMul, DMatrixSlice, DMatrixSliceMut, Scalar};
|
use nalgebra::{ClosedAdd, ClosedMul, DMatrixSlice, DMatrixSliceMut, Scalar};
|
||||||
use num_traits::{One, Zero};
|
use num_traits::{One, Zero};
|
||||||
|
|
||||||
fn spmm_cs_unexpected_entry() -> OperationError {
|
//fn spmm_cs_unexpected_entry() -> OperationError {
|
||||||
OperationError::from_kind_and_message(
|
// OperationError::from_kind_and_message(
|
||||||
OperationErrorKind::InvalidPattern,
|
// OperationErrorKind::InvalidPattern,
|
||||||
String::from("Found unexpected entry that is not present in `c`."),
|
// String::from("Found unexpected entry that is not present in `c`."),
|
||||||
)
|
// )
|
||||||
}
|
//}
|
||||||
|
|
||||||
/// Helper functionality for implementing CSR/CSC SPMM.
|
/// Helper functionality for implementing CSR/CSC SPMM.
|
||||||
///
|
///
|
||||||
@ -32,28 +35,54 @@ where
|
|||||||
{
|
{
|
||||||
for i in 0..c.pattern().major_dim() {
|
for i in 0..c.pattern().major_dim() {
|
||||||
let a_lane_i = a.get_lane(i).unwrap();
|
let a_lane_i = a.get_lane(i).unwrap();
|
||||||
|
|
||||||
|
let some_val = Zero::zero();
|
||||||
|
let mut scratchpad_values: Vec<T> = vec![some_val; b.pattern().minor_dim()];
|
||||||
|
let mut scratchpad_indices: HashSet<usize> = HashSet::new();
|
||||||
|
|
||||||
let mut c_lane_i = c.get_lane_mut(i).unwrap();
|
let mut c_lane_i = c.get_lane_mut(i).unwrap();
|
||||||
for c_ij in c_lane_i.values_mut() {
|
//let (indices, values) = c_lane_i.indices_and_values_mut();
|
||||||
*c_ij = beta.clone() * c_ij.clone();
|
//indices
|
||||||
}
|
// .iter()
|
||||||
|
// .zip(values.iter())
|
||||||
|
// .for_each(|(id, val)| scratchpad_values[*id] = beta.clone() * val.clone());
|
||||||
|
|
||||||
|
//for (index, c_ij) in c_lane_i.indices_and_values_mut() {
|
||||||
|
// *c_ij = beta.clone() * c_ij.clone();
|
||||||
|
//}
|
||||||
|
|
||||||
for (&k, a_ik) in a_lane_i.minor_indices().iter().zip(a_lane_i.values()) {
|
for (&k, a_ik) in a_lane_i.minor_indices().iter().zip(a_lane_i.values()) {
|
||||||
let b_lane_k = b.get_lane(k).unwrap();
|
let b_lane_k = b.get_lane(k).unwrap();
|
||||||
let (mut c_lane_i_cols, mut c_lane_i_values) = c_lane_i.indices_and_values_mut();
|
//let (mut c_lane_i_cols, mut c_lane_i_values) = c_lane_i.indices_and_values_mut();
|
||||||
let alpha_aik = alpha.clone() * a_ik.clone();
|
let alpha_aik = alpha.clone() * a_ik.clone();
|
||||||
for (j, b_kj) in b_lane_k.minor_indices().iter().zip(b_lane_k.values()) {
|
for (j, b_kj) in b_lane_k.minor_indices().iter().zip(b_lane_k.values()) {
|
||||||
// Determine the location in C to append the value
|
// Determine the location in C to append the value
|
||||||
let (c_local_idx, _) = c_lane_i_cols
|
// TODO make a scratchpad and defer the accumulation into C after processing one
|
||||||
.iter()
|
// full row of A.
|
||||||
.enumerate()
|
scratchpad_values[*j] += alpha_aik.clone() * b_kj.clone();
|
||||||
.find(|(_, c_col)| *c_col == j)
|
scratchpad_indices.insert(*j);
|
||||||
.ok_or_else(spmm_cs_unexpected_entry)?;
|
//let (c_local_idx, _) = c_lane_i_cols
|
||||||
|
// .iter()
|
||||||
|
// .enumerate()
|
||||||
|
// .find(|(_, c_col)| *c_col == j)
|
||||||
|
// .ok_or_else(spmm_cs_unexpected_entry)?;
|
||||||
|
|
||||||
c_lane_i_values[c_local_idx] += alpha_aik.clone() * b_kj.clone();
|
//c_lane_i_values[c_local_idx] += alpha_aik.clone() * b_kj.clone();
|
||||||
c_lane_i_cols = &c_lane_i_cols[c_local_idx..];
|
//c_lane_i_cols = &c_lane_i_cols[c_local_idx..];
|
||||||
c_lane_i_values = &mut c_lane_i_values[c_local_idx..];
|
//c_lane_i_values = &mut c_lane_i_values[c_local_idx..];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// sort the indices, and then access the relevant indices (in sorted order) from values
|
||||||
|
// into C.
|
||||||
|
let sorted_indices: Vec<usize> =
|
||||||
|
Itertools::sorted(scratchpad_indices.into_iter()).collect();
|
||||||
|
c_lane_i
|
||||||
|
.values_mut()
|
||||||
|
.iter_mut()
|
||||||
|
.zip(sorted_indices.into_iter())
|
||||||
|
.for_each(|(output_ref, index)| {
|
||||||
|
*output_ref = beta.clone() * output_ref.clone() + scratchpad_values[index].clone()
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
|
Loading…
Reference in New Issue
Block a user