diff --git a/nalgebra-sparse/src/proptest.rs b/nalgebra-sparse/src/proptest.rs index 757c539b..143696ad 100644 --- a/nalgebra-sparse/src/proptest.rs +++ b/nalgebra-sparse/src/proptest.rs @@ -9,13 +9,14 @@ mod proptest_patched; use crate::coo::CooMatrix; use proptest::prelude::*; use proptest::collection::{vec, hash_map, btree_set}; -use nalgebra::Scalar; +use nalgebra::{Scalar, Dim}; use std::cmp::min; use std::iter::{repeat}; use proptest::sample::{Index}; use crate::csr::CsrMatrix; use crate::pattern::SparsityPattern; use crate::csc::CscMatrix; +use nalgebra::proptest::DimRange; fn dense_row_major_coord_strategy(nrows: usize, ncols: usize, nnz: usize) -> impl Strategy> @@ -141,14 +142,14 @@ fn sparse_triplet_strategy(value_strategy: T, /// TODO pub fn coo_no_duplicates( value_strategy: T, - rows: impl Strategy + 'static, - cols: impl Strategy + 'static, + rows: impl Into, + cols: impl Into, max_nonzeros: usize) -> impl Strategy> where T: Strategy + Clone + 'static, T::Value: Scalar, { - (rows, cols) + (rows.into().to_range_inclusive(), cols.into().to_range_inclusive()) .prop_flat_map(move |(nrows, ncols)| { let max_nonzeros = min(max_nonzeros, nrows * ncols); let size_range = 0 ..= max_nonzeros; @@ -182,8 +183,8 @@ where /// for each triplet, but does not consider the sum of triplets pub fn coo_with_duplicates( value_strategy: T, - rows: impl Strategy + 'static, - cols: impl Strategy + 'static, + rows: impl Into, + cols: impl Into, max_nonzeros: usize, max_duplicates: usize) -> impl Strategy> @@ -256,12 +257,12 @@ where /// TODO pub fn sparsity_pattern( - major_lanes: impl Strategy + 'static, - minor_lanes: impl Strategy + 'static, + major_lanes: impl Into, + minor_lanes: impl Into, max_nonzeros: usize) -> impl Strategy { - (major_lanes, minor_lanes) + (major_lanes.into().to_range_inclusive(), minor_lanes.into().to_range_inclusive()) .prop_flat_map(move |(nmajor, nminor)| { let max_nonzeros = min(nmajor * nminor, max_nonzeros); (Just(nmajor), Just(nminor), 0 ..= max_nonzeros) @@ -287,15 +288,17 @@ pub fn sparsity_pattern( /// TODO pub fn csr(value_strategy: T, - rows: impl Strategy + 'static, - cols: impl Strategy + 'static, + rows: impl Into, + cols: impl Into, max_nonzeros: usize) -> impl Strategy> where T: Strategy + Clone + 'static, T::Value: Scalar, { - sparsity_pattern(rows, cols, max_nonzeros) + let rows = rows.into(); + let cols = cols.into(); + sparsity_pattern(rows.lower_bound().value() ..= rows.upper_bound().value(), cols.lower_bound().value() ..= cols.upper_bound().value(), max_nonzeros) .prop_flat_map(move |pattern| { let nnz = pattern.nnz(); let values = vec![value_strategy.clone(); nnz]; @@ -309,15 +312,17 @@ where /// TODO pub fn csc(value_strategy: T, - rows: impl Strategy + 'static, - cols: impl Strategy + 'static, + rows: impl Into, + cols: impl Into, max_nonzeros: usize) -> impl Strategy> where T: Strategy + Clone + 'static, T::Value: Scalar, { - sparsity_pattern(cols, rows, max_nonzeros) + let rows = rows.into(); + let cols = cols.into(); + sparsity_pattern(cols.lower_bound().value() ..= cols.upper_bound().value(), rows.lower_bound().value() ..= rows.upper_bound().value(), max_nonzeros) .prop_flat_map(move |pattern| { let nnz = pattern.nnz(); let values = vec![value_strategy.clone(); nnz]; diff --git a/nalgebra-sparse/tests/unit_tests/ops.rs b/nalgebra-sparse/tests/unit_tests/ops.rs index cef71378..3b91c3be 100644 --- a/nalgebra-sparse/tests/unit_tests/ops.rs +++ b/nalgebra-sparse/tests/unit_tests/ops.rs @@ -70,7 +70,7 @@ fn spmm_csr_dense_args_strategy() -> impl Strategy> let b_shape = if trans_b { (c.ncols(), common_dim) } else { (common_dim, c.ncols()) }; - let a = csr(value_strategy.clone(), Just(a_shape.0), Just(a_shape.1), max_nnz); + let a = csr(value_strategy.clone(), a_shape.0, a_shape.1, max_nnz); let b = matrix(value_strategy.clone(), b_shape.0, b_shape.1); // We use the same values for alpha, beta parameters as for matrix elements @@ -179,7 +179,7 @@ fn pattern_strategy() -> impl Strategy { fn spadd_pattern_strategy() -> impl Strategy { pattern_strategy() .prop_flat_map(|a| { - let b = sparsity_pattern(Just(a.major_dim()), Just(a.minor_dim()), PROPTEST_MAX_NNZ); + let b = sparsity_pattern(a.major_dim(), a.minor_dim(), PROPTEST_MAX_NNZ); (Just(a), b) }) } @@ -188,7 +188,7 @@ fn spadd_pattern_strategy() -> impl Strategy impl Strategy { pattern_strategy() .prop_flat_map(|a| { - let b = sparsity_pattern(Just(a.minor_dim()), PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ); + let b = sparsity_pattern(a.minor_dim(), PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ); (Just(a), b) }) } @@ -269,7 +269,7 @@ fn csc_invertible_diagonal() -> impl Strategy> { fn csc_square_with_non_zero_diagonals() -> impl Strategy> { csc_invertible_diagonal() .prop_flat_map(|d| { - csc(value_strategy::(), Just(d.nrows()), Just(d.nrows()), PROPTEST_MAX_NNZ) + csc(value_strategy::(), d.nrows(), d.nrows(), PROPTEST_MAX_NNZ) .prop_map(move |mut c| { for (i, j, v) in c.triplet_iter_mut() { if i == j { @@ -412,7 +412,7 @@ proptest! { (a, b) in csr_strategy() .prop_flat_map(|a| { - let b = csr(PROPTEST_I32_VALUE_STRATEGY, Just(a.nrows()), Just(a.ncols()), PROPTEST_MAX_NNZ); + let b = csr(PROPTEST_I32_VALUE_STRATEGY, a.nrows(), a.ncols(), PROPTEST_MAX_NNZ); (Just(a), b) })) { @@ -448,7 +448,7 @@ proptest! { (a, b) in csr_strategy() .prop_flat_map(|a| { - let b = csr(PROPTEST_I32_VALUE_STRATEGY, Just(a.nrows()), Just(a.ncols()), PROPTEST_MAX_NNZ); + let b = csr(PROPTEST_I32_VALUE_STRATEGY, a.nrows(), a.ncols(), PROPTEST_MAX_NNZ); (Just(a), b) })) { @@ -606,7 +606,7 @@ proptest! { .prop_flat_map(|a| { let max_nnz = PROPTEST_MAX_NNZ; let cols = PROPTEST_MATRIX_DIM; - let b = csr(PROPTEST_I32_VALUE_STRATEGY, Just(a.ncols()), cols, max_nnz); + let b = csr(PROPTEST_I32_VALUE_STRATEGY, a.ncols(), cols, max_nnz); (Just(a), b) })) { @@ -713,7 +713,7 @@ proptest! { .prop_flat_map(|a| { let max_nnz = PROPTEST_MAX_NNZ; let cols = PROPTEST_MATRIX_DIM; - let b = csc(PROPTEST_I32_VALUE_STRATEGY, Just(a.ncols()), cols, max_nnz); + let b = csc(PROPTEST_I32_VALUE_STRATEGY, a.ncols(), cols, max_nnz); (Just(a), b) }) .prop_map(|(a, b)| { @@ -865,7 +865,7 @@ proptest! { (a, b) in csc_strategy() .prop_flat_map(|a| { - let b = csc(PROPTEST_I32_VALUE_STRATEGY, Just(a.nrows()), Just(a.ncols()), PROPTEST_MAX_NNZ); + let b = csc(PROPTEST_I32_VALUE_STRATEGY, a.nrows(), a.ncols(), PROPTEST_MAX_NNZ); (Just(a), b) })) { @@ -901,7 +901,7 @@ proptest! { (a, b) in csc_strategy() .prop_flat_map(|a| { - let b = csc(PROPTEST_I32_VALUE_STRATEGY, Just(a.nrows()), Just(a.ncols()), PROPTEST_MAX_NNZ); + let b = csc(PROPTEST_I32_VALUE_STRATEGY, a.nrows(), a.ncols(), PROPTEST_MAX_NNZ); (Just(a), b) })) {