Change proptest strategies to use DimRange

This commit is contained in:
Andreas Longva 2021-01-20 17:43:01 +01:00
parent 9cd1540496
commit 31c911d4fb
2 changed files with 30 additions and 25 deletions

View File

@ -9,13 +9,14 @@ mod proptest_patched;
use crate::coo::CooMatrix;
use proptest::prelude::*;
use proptest::collection::{vec, hash_map, btree_set};
use nalgebra::Scalar;
use nalgebra::{Scalar, Dim};
use std::cmp::min;
use std::iter::{repeat};
use proptest::sample::{Index};
use crate::csr::CsrMatrix;
use crate::pattern::SparsityPattern;
use crate::csc::CscMatrix;
use nalgebra::proptest::DimRange;
fn dense_row_major_coord_strategy(nrows: usize, ncols: usize, nnz: usize)
-> impl Strategy<Value=Vec<(usize, usize)>>
@ -141,14 +142,14 @@ fn sparse_triplet_strategy<T>(value_strategy: T,
/// TODO
pub fn coo_no_duplicates<T>(
value_strategy: T,
rows: impl Strategy<Value=usize> + 'static,
cols: impl Strategy<Value=usize> + 'static,
rows: impl Into<DimRange>,
cols: impl Into<DimRange>,
max_nonzeros: usize) -> impl Strategy<Value=CooMatrix<T::Value>>
where
T: Strategy + Clone + 'static,
T::Value: Scalar,
{
(rows, cols)
(rows.into().to_range_inclusive(), cols.into().to_range_inclusive())
.prop_flat_map(move |(nrows, ncols)| {
let max_nonzeros = min(max_nonzeros, nrows * ncols);
let size_range = 0 ..= max_nonzeros;
@ -182,8 +183,8 @@ where
/// for each triplet, but does not consider the sum of triplets
pub fn coo_with_duplicates<T>(
value_strategy: T,
rows: impl Strategy<Value=usize> + 'static,
cols: impl Strategy<Value=usize> + 'static,
rows: impl Into<DimRange>,
cols: impl Into<DimRange>,
max_nonzeros: usize,
max_duplicates: usize)
-> impl Strategy<Value=CooMatrix<T::Value>>
@ -256,12 +257,12 @@ where
/// TODO
pub fn sparsity_pattern(
major_lanes: impl Strategy<Value=usize> + 'static,
minor_lanes: impl Strategy<Value=usize> + 'static,
major_lanes: impl Into<DimRange>,
minor_lanes: impl Into<DimRange>,
max_nonzeros: usize)
-> impl Strategy<Value=SparsityPattern>
{
(major_lanes, minor_lanes)
(major_lanes.into().to_range_inclusive(), minor_lanes.into().to_range_inclusive())
.prop_flat_map(move |(nmajor, nminor)| {
let max_nonzeros = min(nmajor * nminor, max_nonzeros);
(Just(nmajor), Just(nminor), 0 ..= max_nonzeros)
@ -287,15 +288,17 @@ pub fn sparsity_pattern(
/// TODO
pub fn csr<T>(value_strategy: T,
rows: impl Strategy<Value=usize> + 'static,
cols: impl Strategy<Value=usize> + 'static,
rows: impl Into<DimRange>,
cols: impl Into<DimRange>,
max_nonzeros: usize)
-> impl Strategy<Value=CsrMatrix<T::Value>>
where
T: Strategy + Clone + 'static,
T::Value: Scalar,
{
sparsity_pattern(rows, cols, max_nonzeros)
let rows = rows.into();
let cols = cols.into();
sparsity_pattern(rows.lower_bound().value() ..= rows.upper_bound().value(), cols.lower_bound().value() ..= cols.upper_bound().value(), max_nonzeros)
.prop_flat_map(move |pattern| {
let nnz = pattern.nnz();
let values = vec![value_strategy.clone(); nnz];
@ -309,15 +312,17 @@ where
/// TODO
pub fn csc<T>(value_strategy: T,
rows: impl Strategy<Value=usize> + 'static,
cols: impl Strategy<Value=usize> + 'static,
rows: impl Into<DimRange>,
cols: impl Into<DimRange>,
max_nonzeros: usize)
-> impl Strategy<Value=CscMatrix<T::Value>>
where
T: Strategy + Clone + 'static,
T::Value: Scalar,
{
sparsity_pattern(cols, rows, max_nonzeros)
let rows = rows.into();
let cols = cols.into();
sparsity_pattern(cols.lower_bound().value() ..= cols.upper_bound().value(), rows.lower_bound().value() ..= rows.upper_bound().value(), max_nonzeros)
.prop_flat_map(move |pattern| {
let nnz = pattern.nnz();
let values = vec![value_strategy.clone(); nnz];

View File

@ -70,7 +70,7 @@ fn spmm_csr_dense_args_strategy() -> impl Strategy<Value=SpmmCsrDenseArgs<i32>>
let b_shape =
if trans_b { (c.ncols(), common_dim) }
else { (common_dim, c.ncols()) };
let a = csr(value_strategy.clone(), Just(a_shape.0), Just(a_shape.1), max_nnz);
let a = csr(value_strategy.clone(), a_shape.0, a_shape.1, max_nnz);
let b = matrix(value_strategy.clone(), b_shape.0, b_shape.1);
// We use the same values for alpha, beta parameters as for matrix elements
@ -179,7 +179,7 @@ fn pattern_strategy() -> impl Strategy<Value=SparsityPattern> {
fn spadd_pattern_strategy() -> impl Strategy<Value=(SparsityPattern, SparsityPattern)> {
pattern_strategy()
.prop_flat_map(|a| {
let b = sparsity_pattern(Just(a.major_dim()), Just(a.minor_dim()), PROPTEST_MAX_NNZ);
let b = sparsity_pattern(a.major_dim(), a.minor_dim(), PROPTEST_MAX_NNZ);
(Just(a), b)
})
}
@ -188,7 +188,7 @@ fn spadd_pattern_strategy() -> impl Strategy<Value=(SparsityPattern, SparsityPat
fn spmm_csr_pattern_strategy() -> impl Strategy<Value=(SparsityPattern, SparsityPattern)> {
pattern_strategy()
.prop_flat_map(|a| {
let b = sparsity_pattern(Just(a.minor_dim()), PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ);
let b = sparsity_pattern(a.minor_dim(), PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ);
(Just(a), b)
})
}
@ -269,7 +269,7 @@ fn csc_invertible_diagonal() -> impl Strategy<Value=CscMatrix<f64>> {
fn csc_square_with_non_zero_diagonals() -> impl Strategy<Value=CscMatrix<f64>> {
csc_invertible_diagonal()
.prop_flat_map(|d| {
csc(value_strategy::<f64>(), Just(d.nrows()), Just(d.nrows()), PROPTEST_MAX_NNZ)
csc(value_strategy::<f64>(), d.nrows(), d.nrows(), PROPTEST_MAX_NNZ)
.prop_map(move |mut c| {
for (i, j, v) in c.triplet_iter_mut() {
if i == j {
@ -412,7 +412,7 @@ proptest! {
(a, b)
in csr_strategy()
.prop_flat_map(|a| {
let b = csr(PROPTEST_I32_VALUE_STRATEGY, Just(a.nrows()), Just(a.ncols()), PROPTEST_MAX_NNZ);
let b = csr(PROPTEST_I32_VALUE_STRATEGY, a.nrows(), a.ncols(), PROPTEST_MAX_NNZ);
(Just(a), b)
}))
{
@ -448,7 +448,7 @@ proptest! {
(a, b)
in csr_strategy()
.prop_flat_map(|a| {
let b = csr(PROPTEST_I32_VALUE_STRATEGY, Just(a.nrows()), Just(a.ncols()), PROPTEST_MAX_NNZ);
let b = csr(PROPTEST_I32_VALUE_STRATEGY, a.nrows(), a.ncols(), PROPTEST_MAX_NNZ);
(Just(a), b)
}))
{
@ -606,7 +606,7 @@ proptest! {
.prop_flat_map(|a| {
let max_nnz = PROPTEST_MAX_NNZ;
let cols = PROPTEST_MATRIX_DIM;
let b = csr(PROPTEST_I32_VALUE_STRATEGY, Just(a.ncols()), cols, max_nnz);
let b = csr(PROPTEST_I32_VALUE_STRATEGY, a.ncols(), cols, max_nnz);
(Just(a), b)
}))
{
@ -713,7 +713,7 @@ proptest! {
.prop_flat_map(|a| {
let max_nnz = PROPTEST_MAX_NNZ;
let cols = PROPTEST_MATRIX_DIM;
let b = csc(PROPTEST_I32_VALUE_STRATEGY, Just(a.ncols()), cols, max_nnz);
let b = csc(PROPTEST_I32_VALUE_STRATEGY, a.ncols(), cols, max_nnz);
(Just(a), b)
})
.prop_map(|(a, b)| {
@ -865,7 +865,7 @@ proptest! {
(a, b)
in csc_strategy()
.prop_flat_map(|a| {
let b = csc(PROPTEST_I32_VALUE_STRATEGY, Just(a.nrows()), Just(a.ncols()), PROPTEST_MAX_NNZ);
let b = csc(PROPTEST_I32_VALUE_STRATEGY, a.nrows(), a.ncols(), PROPTEST_MAX_NNZ);
(Just(a), b)
}))
{
@ -901,7 +901,7 @@ proptest! {
(a, b)
in csc_strategy()
.prop_flat_map(|a| {
let b = csc(PROPTEST_I32_VALUE_STRATEGY, Just(a.nrows()), Just(a.ncols()), PROPTEST_MAX_NNZ);
let b = csc(PROPTEST_I32_VALUE_STRATEGY, a.nrows(), a.ncols(), PROPTEST_MAX_NNZ);
(Just(a), b)
}))
{