Change proptest strategies to use DimRange
This commit is contained in:
parent
9cd1540496
commit
31c911d4fb
|
@ -9,13 +9,14 @@ mod proptest_patched;
|
|||
use crate::coo::CooMatrix;
|
||||
use proptest::prelude::*;
|
||||
use proptest::collection::{vec, hash_map, btree_set};
|
||||
use nalgebra::Scalar;
|
||||
use nalgebra::{Scalar, Dim};
|
||||
use std::cmp::min;
|
||||
use std::iter::{repeat};
|
||||
use proptest::sample::{Index};
|
||||
use crate::csr::CsrMatrix;
|
||||
use crate::pattern::SparsityPattern;
|
||||
use crate::csc::CscMatrix;
|
||||
use nalgebra::proptest::DimRange;
|
||||
|
||||
fn dense_row_major_coord_strategy(nrows: usize, ncols: usize, nnz: usize)
|
||||
-> impl Strategy<Value=Vec<(usize, usize)>>
|
||||
|
@ -141,14 +142,14 @@ fn sparse_triplet_strategy<T>(value_strategy: T,
|
|||
/// TODO
|
||||
pub fn coo_no_duplicates<T>(
|
||||
value_strategy: T,
|
||||
rows: impl Strategy<Value=usize> + 'static,
|
||||
cols: impl Strategy<Value=usize> + 'static,
|
||||
rows: impl Into<DimRange>,
|
||||
cols: impl Into<DimRange>,
|
||||
max_nonzeros: usize) -> impl Strategy<Value=CooMatrix<T::Value>>
|
||||
where
|
||||
T: Strategy + Clone + 'static,
|
||||
T::Value: Scalar,
|
||||
{
|
||||
(rows, cols)
|
||||
(rows.into().to_range_inclusive(), cols.into().to_range_inclusive())
|
||||
.prop_flat_map(move |(nrows, ncols)| {
|
||||
let max_nonzeros = min(max_nonzeros, nrows * ncols);
|
||||
let size_range = 0 ..= max_nonzeros;
|
||||
|
@ -182,8 +183,8 @@ where
|
|||
/// for each triplet, but does not consider the sum of triplets
|
||||
pub fn coo_with_duplicates<T>(
|
||||
value_strategy: T,
|
||||
rows: impl Strategy<Value=usize> + 'static,
|
||||
cols: impl Strategy<Value=usize> + 'static,
|
||||
rows: impl Into<DimRange>,
|
||||
cols: impl Into<DimRange>,
|
||||
max_nonzeros: usize,
|
||||
max_duplicates: usize)
|
||||
-> impl Strategy<Value=CooMatrix<T::Value>>
|
||||
|
@ -256,12 +257,12 @@ where
|
|||
|
||||
/// TODO
|
||||
pub fn sparsity_pattern(
|
||||
major_lanes: impl Strategy<Value=usize> + 'static,
|
||||
minor_lanes: impl Strategy<Value=usize> + 'static,
|
||||
major_lanes: impl Into<DimRange>,
|
||||
minor_lanes: impl Into<DimRange>,
|
||||
max_nonzeros: usize)
|
||||
-> impl Strategy<Value=SparsityPattern>
|
||||
{
|
||||
(major_lanes, minor_lanes)
|
||||
(major_lanes.into().to_range_inclusive(), minor_lanes.into().to_range_inclusive())
|
||||
.prop_flat_map(move |(nmajor, nminor)| {
|
||||
let max_nonzeros = min(nmajor * nminor, max_nonzeros);
|
||||
(Just(nmajor), Just(nminor), 0 ..= max_nonzeros)
|
||||
|
@ -287,15 +288,17 @@ pub fn sparsity_pattern(
|
|||
|
||||
/// TODO
|
||||
pub fn csr<T>(value_strategy: T,
|
||||
rows: impl Strategy<Value=usize> + 'static,
|
||||
cols: impl Strategy<Value=usize> + 'static,
|
||||
rows: impl Into<DimRange>,
|
||||
cols: impl Into<DimRange>,
|
||||
max_nonzeros: usize)
|
||||
-> impl Strategy<Value=CsrMatrix<T::Value>>
|
||||
where
|
||||
T: Strategy + Clone + 'static,
|
||||
T::Value: Scalar,
|
||||
{
|
||||
sparsity_pattern(rows, cols, max_nonzeros)
|
||||
let rows = rows.into();
|
||||
let cols = cols.into();
|
||||
sparsity_pattern(rows.lower_bound().value() ..= rows.upper_bound().value(), cols.lower_bound().value() ..= cols.upper_bound().value(), max_nonzeros)
|
||||
.prop_flat_map(move |pattern| {
|
||||
let nnz = pattern.nnz();
|
||||
let values = vec![value_strategy.clone(); nnz];
|
||||
|
@ -309,15 +312,17 @@ where
|
|||
|
||||
/// TODO
|
||||
pub fn csc<T>(value_strategy: T,
|
||||
rows: impl Strategy<Value=usize> + 'static,
|
||||
cols: impl Strategy<Value=usize> + 'static,
|
||||
rows: impl Into<DimRange>,
|
||||
cols: impl Into<DimRange>,
|
||||
max_nonzeros: usize)
|
||||
-> impl Strategy<Value=CscMatrix<T::Value>>
|
||||
where
|
||||
T: Strategy + Clone + 'static,
|
||||
T::Value: Scalar,
|
||||
{
|
||||
sparsity_pattern(cols, rows, max_nonzeros)
|
||||
let rows = rows.into();
|
||||
let cols = cols.into();
|
||||
sparsity_pattern(cols.lower_bound().value() ..= cols.upper_bound().value(), rows.lower_bound().value() ..= rows.upper_bound().value(), max_nonzeros)
|
||||
.prop_flat_map(move |pattern| {
|
||||
let nnz = pattern.nnz();
|
||||
let values = vec![value_strategy.clone(); nnz];
|
||||
|
|
|
@ -70,7 +70,7 @@ fn spmm_csr_dense_args_strategy() -> impl Strategy<Value=SpmmCsrDenseArgs<i32>>
|
|||
let b_shape =
|
||||
if trans_b { (c.ncols(), common_dim) }
|
||||
else { (common_dim, c.ncols()) };
|
||||
let a = csr(value_strategy.clone(), Just(a_shape.0), Just(a_shape.1), max_nnz);
|
||||
let a = csr(value_strategy.clone(), a_shape.0, a_shape.1, max_nnz);
|
||||
let b = matrix(value_strategy.clone(), b_shape.0, b_shape.1);
|
||||
|
||||
// We use the same values for alpha, beta parameters as for matrix elements
|
||||
|
@ -179,7 +179,7 @@ fn pattern_strategy() -> impl Strategy<Value=SparsityPattern> {
|
|||
fn spadd_pattern_strategy() -> impl Strategy<Value=(SparsityPattern, SparsityPattern)> {
|
||||
pattern_strategy()
|
||||
.prop_flat_map(|a| {
|
||||
let b = sparsity_pattern(Just(a.major_dim()), Just(a.minor_dim()), PROPTEST_MAX_NNZ);
|
||||
let b = sparsity_pattern(a.major_dim(), a.minor_dim(), PROPTEST_MAX_NNZ);
|
||||
(Just(a), b)
|
||||
})
|
||||
}
|
||||
|
@ -188,7 +188,7 @@ fn spadd_pattern_strategy() -> impl Strategy<Value=(SparsityPattern, SparsityPat
|
|||
fn spmm_csr_pattern_strategy() -> impl Strategy<Value=(SparsityPattern, SparsityPattern)> {
|
||||
pattern_strategy()
|
||||
.prop_flat_map(|a| {
|
||||
let b = sparsity_pattern(Just(a.minor_dim()), PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ);
|
||||
let b = sparsity_pattern(a.minor_dim(), PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ);
|
||||
(Just(a), b)
|
||||
})
|
||||
}
|
||||
|
@ -269,7 +269,7 @@ fn csc_invertible_diagonal() -> impl Strategy<Value=CscMatrix<f64>> {
|
|||
fn csc_square_with_non_zero_diagonals() -> impl Strategy<Value=CscMatrix<f64>> {
|
||||
csc_invertible_diagonal()
|
||||
.prop_flat_map(|d| {
|
||||
csc(value_strategy::<f64>(), Just(d.nrows()), Just(d.nrows()), PROPTEST_MAX_NNZ)
|
||||
csc(value_strategy::<f64>(), d.nrows(), d.nrows(), PROPTEST_MAX_NNZ)
|
||||
.prop_map(move |mut c| {
|
||||
for (i, j, v) in c.triplet_iter_mut() {
|
||||
if i == j {
|
||||
|
@ -412,7 +412,7 @@ proptest! {
|
|||
(a, b)
|
||||
in csr_strategy()
|
||||
.prop_flat_map(|a| {
|
||||
let b = csr(PROPTEST_I32_VALUE_STRATEGY, Just(a.nrows()), Just(a.ncols()), PROPTEST_MAX_NNZ);
|
||||
let b = csr(PROPTEST_I32_VALUE_STRATEGY, a.nrows(), a.ncols(), PROPTEST_MAX_NNZ);
|
||||
(Just(a), b)
|
||||
}))
|
||||
{
|
||||
|
@ -448,7 +448,7 @@ proptest! {
|
|||
(a, b)
|
||||
in csr_strategy()
|
||||
.prop_flat_map(|a| {
|
||||
let b = csr(PROPTEST_I32_VALUE_STRATEGY, Just(a.nrows()), Just(a.ncols()), PROPTEST_MAX_NNZ);
|
||||
let b = csr(PROPTEST_I32_VALUE_STRATEGY, a.nrows(), a.ncols(), PROPTEST_MAX_NNZ);
|
||||
(Just(a), b)
|
||||
}))
|
||||
{
|
||||
|
@ -606,7 +606,7 @@ proptest! {
|
|||
.prop_flat_map(|a| {
|
||||
let max_nnz = PROPTEST_MAX_NNZ;
|
||||
let cols = PROPTEST_MATRIX_DIM;
|
||||
let b = csr(PROPTEST_I32_VALUE_STRATEGY, Just(a.ncols()), cols, max_nnz);
|
||||
let b = csr(PROPTEST_I32_VALUE_STRATEGY, a.ncols(), cols, max_nnz);
|
||||
(Just(a), b)
|
||||
}))
|
||||
{
|
||||
|
@ -713,7 +713,7 @@ proptest! {
|
|||
.prop_flat_map(|a| {
|
||||
let max_nnz = PROPTEST_MAX_NNZ;
|
||||
let cols = PROPTEST_MATRIX_DIM;
|
||||
let b = csc(PROPTEST_I32_VALUE_STRATEGY, Just(a.ncols()), cols, max_nnz);
|
||||
let b = csc(PROPTEST_I32_VALUE_STRATEGY, a.ncols(), cols, max_nnz);
|
||||
(Just(a), b)
|
||||
})
|
||||
.prop_map(|(a, b)| {
|
||||
|
@ -865,7 +865,7 @@ proptest! {
|
|||
(a, b)
|
||||
in csc_strategy()
|
||||
.prop_flat_map(|a| {
|
||||
let b = csc(PROPTEST_I32_VALUE_STRATEGY, Just(a.nrows()), Just(a.ncols()), PROPTEST_MAX_NNZ);
|
||||
let b = csc(PROPTEST_I32_VALUE_STRATEGY, a.nrows(), a.ncols(), PROPTEST_MAX_NNZ);
|
||||
(Just(a), b)
|
||||
}))
|
||||
{
|
||||
|
@ -901,7 +901,7 @@ proptest! {
|
|||
(a, b)
|
||||
in csc_strategy()
|
||||
.prop_flat_map(|a| {
|
||||
let b = csc(PROPTEST_I32_VALUE_STRATEGY, Just(a.nrows()), Just(a.ncols()), PROPTEST_MAX_NNZ);
|
||||
let b = csc(PROPTEST_I32_VALUE_STRATEGY, a.nrows(), a.ncols(), PROPTEST_MAX_NNZ);
|
||||
(Just(a), b)
|
||||
}))
|
||||
{
|
||||
|
|
Loading…
Reference in New Issue