Change proptest strategies to use DimRange

This commit is contained in:
Andreas Longva 2021-01-20 17:43:01 +01:00
parent 9cd1540496
commit 31c911d4fb
2 changed files with 30 additions and 25 deletions

View File

@ -9,13 +9,14 @@ mod proptest_patched;
use crate::coo::CooMatrix; use crate::coo::CooMatrix;
use proptest::prelude::*; use proptest::prelude::*;
use proptest::collection::{vec, hash_map, btree_set}; use proptest::collection::{vec, hash_map, btree_set};
use nalgebra::Scalar; use nalgebra::{Scalar, Dim};
use std::cmp::min; use std::cmp::min;
use std::iter::{repeat}; use std::iter::{repeat};
use proptest::sample::{Index}; use proptest::sample::{Index};
use crate::csr::CsrMatrix; use crate::csr::CsrMatrix;
use crate::pattern::SparsityPattern; use crate::pattern::SparsityPattern;
use crate::csc::CscMatrix; use crate::csc::CscMatrix;
use nalgebra::proptest::DimRange;
fn dense_row_major_coord_strategy(nrows: usize, ncols: usize, nnz: usize) fn dense_row_major_coord_strategy(nrows: usize, ncols: usize, nnz: usize)
-> impl Strategy<Value=Vec<(usize, usize)>> -> impl Strategy<Value=Vec<(usize, usize)>>
@ -141,14 +142,14 @@ fn sparse_triplet_strategy<T>(value_strategy: T,
/// TODO /// TODO
pub fn coo_no_duplicates<T>( pub fn coo_no_duplicates<T>(
value_strategy: T, value_strategy: T,
rows: impl Strategy<Value=usize> + 'static, rows: impl Into<DimRange>,
cols: impl Strategy<Value=usize> + 'static, cols: impl Into<DimRange>,
max_nonzeros: usize) -> impl Strategy<Value=CooMatrix<T::Value>> max_nonzeros: usize) -> impl Strategy<Value=CooMatrix<T::Value>>
where where
T: Strategy + Clone + 'static, T: Strategy + Clone + 'static,
T::Value: Scalar, T::Value: Scalar,
{ {
(rows, cols) (rows.into().to_range_inclusive(), cols.into().to_range_inclusive())
.prop_flat_map(move |(nrows, ncols)| { .prop_flat_map(move |(nrows, ncols)| {
let max_nonzeros = min(max_nonzeros, nrows * ncols); let max_nonzeros = min(max_nonzeros, nrows * ncols);
let size_range = 0 ..= max_nonzeros; let size_range = 0 ..= max_nonzeros;
@ -182,8 +183,8 @@ where
/// for each triplet, but does not consider the sum of triplets /// for each triplet, but does not consider the sum of triplets
pub fn coo_with_duplicates<T>( pub fn coo_with_duplicates<T>(
value_strategy: T, value_strategy: T,
rows: impl Strategy<Value=usize> + 'static, rows: impl Into<DimRange>,
cols: impl Strategy<Value=usize> + 'static, cols: impl Into<DimRange>,
max_nonzeros: usize, max_nonzeros: usize,
max_duplicates: usize) max_duplicates: usize)
-> impl Strategy<Value=CooMatrix<T::Value>> -> impl Strategy<Value=CooMatrix<T::Value>>
@ -256,12 +257,12 @@ where
/// TODO /// TODO
pub fn sparsity_pattern( pub fn sparsity_pattern(
major_lanes: impl Strategy<Value=usize> + 'static, major_lanes: impl Into<DimRange>,
minor_lanes: impl Strategy<Value=usize> + 'static, minor_lanes: impl Into<DimRange>,
max_nonzeros: usize) max_nonzeros: usize)
-> impl Strategy<Value=SparsityPattern> -> impl Strategy<Value=SparsityPattern>
{ {
(major_lanes, minor_lanes) (major_lanes.into().to_range_inclusive(), minor_lanes.into().to_range_inclusive())
.prop_flat_map(move |(nmajor, nminor)| { .prop_flat_map(move |(nmajor, nminor)| {
let max_nonzeros = min(nmajor * nminor, max_nonzeros); let max_nonzeros = min(nmajor * nminor, max_nonzeros);
(Just(nmajor), Just(nminor), 0 ..= max_nonzeros) (Just(nmajor), Just(nminor), 0 ..= max_nonzeros)
@ -287,15 +288,17 @@ pub fn sparsity_pattern(
/// TODO /// TODO
pub fn csr<T>(value_strategy: T, pub fn csr<T>(value_strategy: T,
rows: impl Strategy<Value=usize> + 'static, rows: impl Into<DimRange>,
cols: impl Strategy<Value=usize> + 'static, cols: impl Into<DimRange>,
max_nonzeros: usize) max_nonzeros: usize)
-> impl Strategy<Value=CsrMatrix<T::Value>> -> impl Strategy<Value=CsrMatrix<T::Value>>
where where
T: Strategy + Clone + 'static, T: Strategy + Clone + 'static,
T::Value: Scalar, T::Value: Scalar,
{ {
sparsity_pattern(rows, cols, max_nonzeros) let rows = rows.into();
let cols = cols.into();
sparsity_pattern(rows.lower_bound().value() ..= rows.upper_bound().value(), cols.lower_bound().value() ..= cols.upper_bound().value(), max_nonzeros)
.prop_flat_map(move |pattern| { .prop_flat_map(move |pattern| {
let nnz = pattern.nnz(); let nnz = pattern.nnz();
let values = vec![value_strategy.clone(); nnz]; let values = vec![value_strategy.clone(); nnz];
@ -309,15 +312,17 @@ where
/// TODO /// TODO
pub fn csc<T>(value_strategy: T, pub fn csc<T>(value_strategy: T,
rows: impl Strategy<Value=usize> + 'static, rows: impl Into<DimRange>,
cols: impl Strategy<Value=usize> + 'static, cols: impl Into<DimRange>,
max_nonzeros: usize) max_nonzeros: usize)
-> impl Strategy<Value=CscMatrix<T::Value>> -> impl Strategy<Value=CscMatrix<T::Value>>
where where
T: Strategy + Clone + 'static, T: Strategy + Clone + 'static,
T::Value: Scalar, T::Value: Scalar,
{ {
sparsity_pattern(cols, rows, max_nonzeros) let rows = rows.into();
let cols = cols.into();
sparsity_pattern(cols.lower_bound().value() ..= cols.upper_bound().value(), rows.lower_bound().value() ..= rows.upper_bound().value(), max_nonzeros)
.prop_flat_map(move |pattern| { .prop_flat_map(move |pattern| {
let nnz = pattern.nnz(); let nnz = pattern.nnz();
let values = vec![value_strategy.clone(); nnz]; let values = vec![value_strategy.clone(); nnz];

View File

@ -70,7 +70,7 @@ fn spmm_csr_dense_args_strategy() -> impl Strategy<Value=SpmmCsrDenseArgs<i32>>
let b_shape = let b_shape =
if trans_b { (c.ncols(), common_dim) } if trans_b { (c.ncols(), common_dim) }
else { (common_dim, c.ncols()) }; else { (common_dim, c.ncols()) };
let a = csr(value_strategy.clone(), Just(a_shape.0), Just(a_shape.1), max_nnz); let a = csr(value_strategy.clone(), a_shape.0, a_shape.1, max_nnz);
let b = matrix(value_strategy.clone(), b_shape.0, b_shape.1); let b = matrix(value_strategy.clone(), b_shape.0, b_shape.1);
// We use the same values for alpha, beta parameters as for matrix elements // We use the same values for alpha, beta parameters as for matrix elements
@ -179,7 +179,7 @@ fn pattern_strategy() -> impl Strategy<Value=SparsityPattern> {
fn spadd_pattern_strategy() -> impl Strategy<Value=(SparsityPattern, SparsityPattern)> { fn spadd_pattern_strategy() -> impl Strategy<Value=(SparsityPattern, SparsityPattern)> {
pattern_strategy() pattern_strategy()
.prop_flat_map(|a| { .prop_flat_map(|a| {
let b = sparsity_pattern(Just(a.major_dim()), Just(a.minor_dim()), PROPTEST_MAX_NNZ); let b = sparsity_pattern(a.major_dim(), a.minor_dim(), PROPTEST_MAX_NNZ);
(Just(a), b) (Just(a), b)
}) })
} }
@ -188,7 +188,7 @@ fn spadd_pattern_strategy() -> impl Strategy<Value=(SparsityPattern, SparsityPat
fn spmm_csr_pattern_strategy() -> impl Strategy<Value=(SparsityPattern, SparsityPattern)> { fn spmm_csr_pattern_strategy() -> impl Strategy<Value=(SparsityPattern, SparsityPattern)> {
pattern_strategy() pattern_strategy()
.prop_flat_map(|a| { .prop_flat_map(|a| {
let b = sparsity_pattern(Just(a.minor_dim()), PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ); let b = sparsity_pattern(a.minor_dim(), PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ);
(Just(a), b) (Just(a), b)
}) })
} }
@ -269,7 +269,7 @@ fn csc_invertible_diagonal() -> impl Strategy<Value=CscMatrix<f64>> {
fn csc_square_with_non_zero_diagonals() -> impl Strategy<Value=CscMatrix<f64>> { fn csc_square_with_non_zero_diagonals() -> impl Strategy<Value=CscMatrix<f64>> {
csc_invertible_diagonal() csc_invertible_diagonal()
.prop_flat_map(|d| { .prop_flat_map(|d| {
csc(value_strategy::<f64>(), Just(d.nrows()), Just(d.nrows()), PROPTEST_MAX_NNZ) csc(value_strategy::<f64>(), d.nrows(), d.nrows(), PROPTEST_MAX_NNZ)
.prop_map(move |mut c| { .prop_map(move |mut c| {
for (i, j, v) in c.triplet_iter_mut() { for (i, j, v) in c.triplet_iter_mut() {
if i == j { if i == j {
@ -412,7 +412,7 @@ proptest! {
(a, b) (a, b)
in csr_strategy() in csr_strategy()
.prop_flat_map(|a| { .prop_flat_map(|a| {
let b = csr(PROPTEST_I32_VALUE_STRATEGY, Just(a.nrows()), Just(a.ncols()), PROPTEST_MAX_NNZ); let b = csr(PROPTEST_I32_VALUE_STRATEGY, a.nrows(), a.ncols(), PROPTEST_MAX_NNZ);
(Just(a), b) (Just(a), b)
})) }))
{ {
@ -448,7 +448,7 @@ proptest! {
(a, b) (a, b)
in csr_strategy() in csr_strategy()
.prop_flat_map(|a| { .prop_flat_map(|a| {
let b = csr(PROPTEST_I32_VALUE_STRATEGY, Just(a.nrows()), Just(a.ncols()), PROPTEST_MAX_NNZ); let b = csr(PROPTEST_I32_VALUE_STRATEGY, a.nrows(), a.ncols(), PROPTEST_MAX_NNZ);
(Just(a), b) (Just(a), b)
})) }))
{ {
@ -606,7 +606,7 @@ proptest! {
.prop_flat_map(|a| { .prop_flat_map(|a| {
let max_nnz = PROPTEST_MAX_NNZ; let max_nnz = PROPTEST_MAX_NNZ;
let cols = PROPTEST_MATRIX_DIM; let cols = PROPTEST_MATRIX_DIM;
let b = csr(PROPTEST_I32_VALUE_STRATEGY, Just(a.ncols()), cols, max_nnz); let b = csr(PROPTEST_I32_VALUE_STRATEGY, a.ncols(), cols, max_nnz);
(Just(a), b) (Just(a), b)
})) }))
{ {
@ -713,7 +713,7 @@ proptest! {
.prop_flat_map(|a| { .prop_flat_map(|a| {
let max_nnz = PROPTEST_MAX_NNZ; let max_nnz = PROPTEST_MAX_NNZ;
let cols = PROPTEST_MATRIX_DIM; let cols = PROPTEST_MATRIX_DIM;
let b = csc(PROPTEST_I32_VALUE_STRATEGY, Just(a.ncols()), cols, max_nnz); let b = csc(PROPTEST_I32_VALUE_STRATEGY, a.ncols(), cols, max_nnz);
(Just(a), b) (Just(a), b)
}) })
.prop_map(|(a, b)| { .prop_map(|(a, b)| {
@ -865,7 +865,7 @@ proptest! {
(a, b) (a, b)
in csc_strategy() in csc_strategy()
.prop_flat_map(|a| { .prop_flat_map(|a| {
let b = csc(PROPTEST_I32_VALUE_STRATEGY, Just(a.nrows()), Just(a.ncols()), PROPTEST_MAX_NNZ); let b = csc(PROPTEST_I32_VALUE_STRATEGY, a.nrows(), a.ncols(), PROPTEST_MAX_NNZ);
(Just(a), b) (Just(a), b)
})) }))
{ {
@ -901,7 +901,7 @@ proptest! {
(a, b) (a, b)
in csc_strategy() in csc_strategy()
.prop_flat_map(|a| { .prop_flat_map(|a| {
let b = csc(PROPTEST_I32_VALUE_STRATEGY, Just(a.nrows()), Just(a.ncols()), PROPTEST_MAX_NNZ); let b = csc(PROPTEST_I32_VALUE_STRATEGY, a.nrows(), a.ncols(), PROPTEST_MAX_NNZ);
(Just(a), b) (Just(a), b)
})) }))
{ {