Add csr, csc, sparsity_pattern proptest generators (untested)
This commit is contained in:
parent
41ce9a23df
commit
6083d24dd6
|
@ -4,11 +4,42 @@
|
||||||
|
|
||||||
use crate::coo::CooMatrix;
|
use crate::coo::CooMatrix;
|
||||||
use proptest::prelude::*;
|
use proptest::prelude::*;
|
||||||
use proptest::collection::{vec, hash_map};
|
use proptest::collection::{vec, hash_map, btree_set};
|
||||||
use nalgebra::Scalar;
|
use nalgebra::Scalar;
|
||||||
use std::cmp::min;
|
use std::cmp::min;
|
||||||
use std::iter::repeat;
|
use std::iter::{repeat};
|
||||||
use proptest::sample::{Index};
|
use proptest::sample::{Index};
|
||||||
|
use crate::csr::CsrMatrix;
|
||||||
|
use crate::pattern::SparsityPattern;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use crate::csc::CscMatrix;
|
||||||
|
|
||||||
|
fn dense_row_major_coord_strategy(nrows: usize, ncols: usize, nnz: usize)
|
||||||
|
-> impl Strategy<Value=Vec<(usize, usize)>>
|
||||||
|
{
|
||||||
|
let mut booleans = vec![true; nnz];
|
||||||
|
booleans.append(&mut vec![false; (nrows * ncols) - nnz]);
|
||||||
|
// Make sure that exactly `nnz` of the booleans are true
|
||||||
|
Just(booleans)
|
||||||
|
// Need to shuffle to make sure they are randomly distributed
|
||||||
|
.prop_shuffle()
|
||||||
|
.prop_map(move |booleans| {
|
||||||
|
booleans
|
||||||
|
.into_iter()
|
||||||
|
.enumerate()
|
||||||
|
.filter_map(|(index, is_entry)| {
|
||||||
|
if is_entry {
|
||||||
|
// Convert linear index to row/col pair
|
||||||
|
let i = index / ncols;
|
||||||
|
let j = index % ncols;
|
||||||
|
Some((i, j))
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
/// A strategy for generating `nnz` triplets.
|
/// A strategy for generating `nnz` triplets.
|
||||||
///
|
///
|
||||||
|
@ -177,4 +208,112 @@ where
|
||||||
}
|
}
|
||||||
coo
|
coo
|
||||||
})
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn sparsity_pattern_from_row_major_coords<I>(nmajor: usize, nminor: usize, coords: I) -> SparsityPattern
|
||||||
|
where
|
||||||
|
I: Iterator<Item=(usize, usize)> + ExactSizeIterator,
|
||||||
|
{
|
||||||
|
let mut minors = Vec::with_capacity(coords.len());
|
||||||
|
let mut offsets = Vec::with_capacity(nmajor + 1);
|
||||||
|
let mut current_major = 0;
|
||||||
|
offsets.push(0);
|
||||||
|
for (idx, (i, j)) in coords.enumerate() {
|
||||||
|
assert!(i >= current_major);
|
||||||
|
assert!(i < nmajor && j < nminor, "Generated coords are out of bounds");
|
||||||
|
while current_major < i{
|
||||||
|
offsets.push(idx);
|
||||||
|
current_major += 1;
|
||||||
|
}
|
||||||
|
minors.push(j);
|
||||||
|
}
|
||||||
|
|
||||||
|
while current_major < nmajor {
|
||||||
|
offsets.push(minors.len());
|
||||||
|
current_major += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
assert_eq!(offsets.first().unwrap(), &0);
|
||||||
|
assert_eq!(offsets.len(), nmajor + 1);
|
||||||
|
|
||||||
|
SparsityPattern::try_from_offsets_and_indices(nmajor,
|
||||||
|
nminor,
|
||||||
|
offsets,
|
||||||
|
minors)
|
||||||
|
.expect("Internal error: Generated sparsity pattern is invalid")
|
||||||
|
}
|
||||||
|
|
||||||
|
/// TODO
|
||||||
|
pub fn sparsity_pattern(
|
||||||
|
major_lanes: impl Strategy<Value=usize> + 'static,
|
||||||
|
minor_lanes: impl Strategy<Value=usize> + 'static,
|
||||||
|
max_nonzeros: usize)
|
||||||
|
-> impl Strategy<Value=SparsityPattern>
|
||||||
|
{
|
||||||
|
(major_lanes, minor_lanes)
|
||||||
|
.prop_flat_map(move |(nmajor, nminor)| {
|
||||||
|
let max_nonzeros = min(nmajor * nminor, max_nonzeros);
|
||||||
|
(Just(nmajor), Just(nminor), 0 ..= max_nonzeros)
|
||||||
|
}).prop_flat_map(move |(nmajor, nminor, nnz)| {
|
||||||
|
if 10 * nnz < nmajor * nminor {
|
||||||
|
// If nnz is small compared to a dense matrix, then use a sparse sampling strategy
|
||||||
|
btree_set((0..nmajor, 0..nminor), nnz)
|
||||||
|
.prop_map(move |coords| {
|
||||||
|
sparsity_pattern_from_row_major_coords(nmajor, nminor, coords.into_iter())
|
||||||
|
})
|
||||||
|
.boxed()
|
||||||
|
} else {
|
||||||
|
// If the required number of nonzeros is sufficiently dense,
|
||||||
|
// we instead use a dense sampling
|
||||||
|
dense_row_major_coord_strategy(nmajor, nminor, nnz)
|
||||||
|
.prop_map(move |triplets| {
|
||||||
|
let coords = triplets.into_iter();
|
||||||
|
sparsity_pattern_from_row_major_coords(nmajor, nminor, coords)
|
||||||
|
}).boxed()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// TODO
|
||||||
|
pub fn csr<T>(value_strategy: T,
|
||||||
|
rows: impl Strategy<Value=usize> + 'static,
|
||||||
|
cols: impl Strategy<Value=usize> + 'static,
|
||||||
|
max_nonzeros: usize)
|
||||||
|
-> impl Strategy<Value=CsrMatrix<T::Value>>
|
||||||
|
where
|
||||||
|
T: Strategy + Clone + 'static,
|
||||||
|
T::Value: Scalar,
|
||||||
|
{
|
||||||
|
sparsity_pattern(rows, cols, max_nonzeros)
|
||||||
|
.prop_flat_map(move |pattern| {
|
||||||
|
let nnz = pattern.nnz();
|
||||||
|
let values = vec![value_strategy.clone(); nnz];
|
||||||
|
(Just(pattern), values)
|
||||||
|
})
|
||||||
|
.prop_map(|(pattern, values)| {
|
||||||
|
CsrMatrix::try_from_pattern_and_values(Arc::new(pattern), values)
|
||||||
|
.expect("Internal error: Generated CsrMatrix is invalid")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// TODO
|
||||||
|
pub fn csc<T>(value_strategy: T,
|
||||||
|
rows: impl Strategy<Value=usize> + 'static,
|
||||||
|
cols: impl Strategy<Value=usize> + 'static,
|
||||||
|
max_nonzeros: usize)
|
||||||
|
-> impl Strategy<Value=CscMatrix<T::Value>>
|
||||||
|
where
|
||||||
|
T: Strategy + Clone + 'static,
|
||||||
|
T::Value: Scalar,
|
||||||
|
{
|
||||||
|
sparsity_pattern(cols, rows, max_nonzeros)
|
||||||
|
.prop_flat_map(move |pattern| {
|
||||||
|
let nnz = pattern.nnz();
|
||||||
|
let values = vec![value_strategy.clone(); nnz];
|
||||||
|
(Just(pattern), values)
|
||||||
|
})
|
||||||
|
.prop_map(|(pattern, values)| {
|
||||||
|
CscMatrix::try_from_pattern_and_values(Arc::new(pattern), values)
|
||||||
|
.expect("Internal error: Generated CscMatrix is invalid")
|
||||||
|
})
|
||||||
}
|
}
|
|
@ -131,4 +131,6 @@ mod slow {
|
||||||
// is contained in the set of visited matrices
|
// is contained in the set of visited matrices
|
||||||
assert!(all_combinations.is_subset(&visited_combinations));
|
assert!(all_combinations.is_subset(&visited_combinations));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: Tests for csr, csc and sparsity_pattern strategies
|
Loading…
Reference in New Issue