Improve and test proptest generators
Due to a bug in proptest, we were required to pull in and modify parts of proptest::strategy::Shuffle. Once the PR below has been merged and released on crates.io, we can remove this code: https://github.com/AltSysrq/proptest/pull/217
This commit is contained in:
parent
3eab45d81b
commit
9cd1540496
|
@ -2,6 +2,10 @@
|
||||||
//!
|
//!
|
||||||
//! TODO: Clarify that this module needs proptest-support feature
|
//! TODO: Clarify that this module needs proptest-support feature
|
||||||
|
|
||||||
|
// Contains some patched code from proptest that we can remove in the (hopefully near) future.
|
||||||
|
// See docs in file for more details.
|
||||||
|
mod proptest_patched;
|
||||||
|
|
||||||
use crate::coo::CooMatrix;
|
use crate::coo::CooMatrix;
|
||||||
use proptest::prelude::*;
|
use proptest::prelude::*;
|
||||||
use proptest::collection::{vec, hash_map, btree_set};
|
use proptest::collection::{vec, hash_map, btree_set};
|
||||||
|
@ -16,12 +20,20 @@ use crate::csc::CscMatrix;
|
||||||
fn dense_row_major_coord_strategy(nrows: usize, ncols: usize, nnz: usize)
|
fn dense_row_major_coord_strategy(nrows: usize, ncols: usize, nnz: usize)
|
||||||
-> impl Strategy<Value=Vec<(usize, usize)>>
|
-> impl Strategy<Value=Vec<(usize, usize)>>
|
||||||
{
|
{
|
||||||
|
assert!(nnz <= nrows * ncols);
|
||||||
let mut booleans = vec![true; nnz];
|
let mut booleans = vec![true; nnz];
|
||||||
booleans.append(&mut vec![false; (nrows * ncols) - nnz]);
|
booleans.append(&mut vec![false; (nrows * ncols) - nnz]);
|
||||||
// Make sure that exactly `nnz` of the booleans are true
|
// Make sure that exactly `nnz` of the booleans are true
|
||||||
Just(booleans)
|
|
||||||
// Need to shuffle to make sure they are randomly distributed
|
// TODO: We cannot use the below code because of a bug in proptest, see
|
||||||
.prop_shuffle()
|
// https://github.com/AltSysrq/proptest/pull/217
|
||||||
|
// so for now we're using a patched version of the Shuffle adapter
|
||||||
|
// (see also docs in `proptest_patched`)
|
||||||
|
// Just(booleans)
|
||||||
|
// // Need to shuffle to make sure they are randomly distributed
|
||||||
|
// .prop_shuffle()
|
||||||
|
|
||||||
|
proptest_patched::Shuffle(Just(booleans))
|
||||||
.prop_map(move |booleans| {
|
.prop_map(move |booleans| {
|
||||||
booleans
|
booleans
|
||||||
.into_iter()
|
.into_iter()
|
||||||
|
@ -265,8 +277,8 @@ pub fn sparsity_pattern(
|
||||||
// If the required number of nonzeros is sufficiently dense,
|
// If the required number of nonzeros is sufficiently dense,
|
||||||
// we instead use a dense sampling
|
// we instead use a dense sampling
|
||||||
dense_row_major_coord_strategy(nmajor, nminor, nnz)
|
dense_row_major_coord_strategy(nmajor, nminor, nnz)
|
||||||
.prop_map(move |triplets| {
|
.prop_map(move |coords| {
|
||||||
let coords = triplets.into_iter();
|
let coords = coords.into_iter();
|
||||||
sparsity_pattern_from_row_major_coords(nmajor, nminor, coords)
|
sparsity_pattern_from_row_major_coords(nmajor, nminor, coords)
|
||||||
}).boxed()
|
}).boxed()
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,146 @@
|
||||||
|
//! Contains a modified implementation of `proptest::strategy::Shuffle`.
|
||||||
|
//!
|
||||||
|
//! The current implementation in `proptest` does not generate all permutations, which is
|
||||||
|
//! problematic for our proptest generators. The issue has been fixed in
|
||||||
|
//! <https://github.com/AltSysrq/proptest/pull/217>
|
||||||
|
//! but it has yet to be merged and released. As soon as this fix makes it into a new release,
|
||||||
|
//! the modified code here can be removed.
|
||||||
|
//!
|
||||||
|
/*!
|
||||||
|
This code has been copied and adapted from
|
||||||
|
https://github.com/AltSysrq/proptest/blob/master/proptest/src/strategy/shuffle.rs
|
||||||
|
The original licensing text is:
|
||||||
|
|
||||||
|
//-
|
||||||
|
// Copyright 2017 Jason Lingle
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
|
||||||
|
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
|
||||||
|
// option. This file may not be copied, modified, or distributed
|
||||||
|
// except according to those terms.
|
||||||
|
|
||||||
|
*/
|
||||||
|
|
||||||
|
use proptest::strategy::{Strategy, Shuffleable, NewTree, ValueTree};
|
||||||
|
use proptest::test_runner::{TestRunner, TestRng};
|
||||||
|
use std::cell::Cell;
|
||||||
|
use proptest::num;
|
||||||
|
use proptest::prelude::Rng;
|
||||||
|
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
#[must_use = "strategies do nothing unless used"]
|
||||||
|
pub struct Shuffle<S>(pub(super) S);
|
||||||
|
|
||||||
|
impl<S: Strategy> Strategy for Shuffle<S>
|
||||||
|
where
|
||||||
|
S::Value: Shuffleable,
|
||||||
|
{
|
||||||
|
type Tree = ShuffleValueTree<S::Tree>;
|
||||||
|
type Value = S::Value;
|
||||||
|
|
||||||
|
fn new_tree(&self, runner: &mut TestRunner) -> NewTree<Self> {
|
||||||
|
let rng = runner.new_rng();
|
||||||
|
|
||||||
|
self.0.new_tree(runner).map(|inner| ShuffleValueTree {
|
||||||
|
inner,
|
||||||
|
rng,
|
||||||
|
dist: Cell::new(None),
|
||||||
|
simplifying_inner: false,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
pub struct ShuffleValueTree<V> {
|
||||||
|
inner: V,
|
||||||
|
rng: TestRng,
|
||||||
|
dist: Cell<Option<num::usize::BinarySearch>>,
|
||||||
|
simplifying_inner: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<V: ValueTree> ShuffleValueTree<V>
|
||||||
|
where
|
||||||
|
V::Value: Shuffleable,
|
||||||
|
{
|
||||||
|
fn init_dist(&self, dflt: usize) -> usize {
|
||||||
|
if self.dist.get().is_none() {
|
||||||
|
self.dist.set(Some(num::usize::BinarySearch::new(dflt)));
|
||||||
|
}
|
||||||
|
|
||||||
|
self.dist.get().unwrap().current()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn force_init_dist(&self) {
|
||||||
|
if self.dist.get().is_none() {
|
||||||
|
let _ = self.init_dist(self.current().shuffle_len());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<V: ValueTree> ValueTree for ShuffleValueTree<V>
|
||||||
|
where
|
||||||
|
V::Value: Shuffleable,
|
||||||
|
{
|
||||||
|
type Value = V::Value;
|
||||||
|
|
||||||
|
fn current(&self) -> V::Value {
|
||||||
|
let mut value = self.inner.current();
|
||||||
|
let len = value.shuffle_len();
|
||||||
|
// The maximum distance to swap elements. This could be larger than
|
||||||
|
// `value` if `value` has reduced size during shrinking; that's OK,
|
||||||
|
// since we only use this to filter swaps.
|
||||||
|
let max_swap = self.init_dist(len);
|
||||||
|
|
||||||
|
// If empty collection or all swaps will be filtered out, there's
|
||||||
|
// nothing to shuffle.
|
||||||
|
if 0 == len || 0 == max_swap {
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut rng = self.rng.clone();
|
||||||
|
|
||||||
|
for start_index in 0..len - 1 {
|
||||||
|
// Determine the other index to be swapped, then skip the swap if
|
||||||
|
// it is too far. This ordering is critical, as it ensures that we
|
||||||
|
// generate the same sequence of random numbers every time.
|
||||||
|
|
||||||
|
// NOTE: The below line is the whole reason for the existence of this adapted code
|
||||||
|
// We need to be able to swap with the same element, so that some elements remain in
|
||||||
|
// place rather than being swapped
|
||||||
|
// let end_index = rng.gen_range(start_index + 1, len);
|
||||||
|
let end_index = rng.gen_range(start_index, len);
|
||||||
|
if end_index - start_index <= max_swap {
|
||||||
|
value.shuffle_swap(start_index, end_index);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
value
|
||||||
|
}
|
||||||
|
|
||||||
|
fn simplify(&mut self) -> bool {
|
||||||
|
if self.simplifying_inner {
|
||||||
|
self.inner.simplify()
|
||||||
|
} else {
|
||||||
|
// Ensure that we've initialised `dist` to *something* to give
|
||||||
|
// consistent non-panicking behaviour even if called in an
|
||||||
|
// unexpected sequence.
|
||||||
|
self.force_init_dist();
|
||||||
|
if self.dist.get_mut().as_mut().unwrap().simplify() {
|
||||||
|
true
|
||||||
|
} else {
|
||||||
|
self.simplifying_inner = true;
|
||||||
|
self.inner.simplify()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn complicate(&mut self) -> bool {
|
||||||
|
if self.simplifying_inner {
|
||||||
|
self.inner.complicate()
|
||||||
|
} else {
|
||||||
|
self.force_init_dist();
|
||||||
|
self.dist.get_mut().as_mut().unwrap().complicate()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -6,7 +6,7 @@ fn coo_no_duplicates_generates_admissible_matrices() {
|
||||||
|
|
||||||
#[cfg(feature = "slow-tests")]
|
#[cfg(feature = "slow-tests")]
|
||||||
mod slow {
|
mod slow {
|
||||||
use nalgebra_sparse::proptest::{coo_with_duplicates, coo_no_duplicates};
|
use nalgebra_sparse::proptest::{coo_with_duplicates, coo_no_duplicates, csr, csc, sparsity_pattern};
|
||||||
use nalgebra::DMatrix;
|
use nalgebra::DMatrix;
|
||||||
|
|
||||||
use proptest::test_runner::TestRunner;
|
use proptest::test_runner::TestRunner;
|
||||||
|
@ -18,6 +18,7 @@ mod slow {
|
||||||
use std::collections::HashSet;
|
use std::collections::HashSet;
|
||||||
use std::iter::repeat;
|
use std::iter::repeat;
|
||||||
use std::ops::RangeInclusive;
|
use std::ops::RangeInclusive;
|
||||||
|
use nalgebra_sparse::csr::CsrMatrix;
|
||||||
|
|
||||||
fn generate_all_possible_matrices(value_range: RangeInclusive<i32>,
|
fn generate_all_possible_matrices(value_range: RangeInclusive<i32>,
|
||||||
rows_range: RangeInclusive<usize>,
|
rows_range: RangeInclusive<usize>,
|
||||||
|
@ -73,19 +74,15 @@ mod slow {
|
||||||
let values = -1..=1;
|
let values = -1..=1;
|
||||||
let rows = 0..=2;
|
let rows = 0..=2;
|
||||||
let cols = 0..=3;
|
let cols = 0..=3;
|
||||||
let strategy = coo_no_duplicates(values.clone(), rows.clone(), cols.clone(), 2 * 3);
|
let max_nnz = rows.end() * cols.end();
|
||||||
|
let strategy = coo_no_duplicates(values.clone(), rows.clone(), cols.clone(), max_nnz);
|
||||||
|
|
||||||
// Enumerate all possible combinations
|
// Enumerate all possible combinations
|
||||||
let all_combinations = generate_all_possible_matrices(values, rows, cols);
|
let all_combinations = generate_all_possible_matrices(values, rows, cols);
|
||||||
|
|
||||||
let mut visited_combinations = HashSet::new();
|
let visited_combinations = sample_matrix_output_space(strategy,
|
||||||
for _ in 0..num_generated_matrices {
|
&mut runner,
|
||||||
let tree = strategy
|
num_generated_matrices);
|
||||||
.new_tree(&mut runner)
|
|
||||||
.expect("Tree generation should not fail");
|
|
||||||
let matrix = tree.current();
|
|
||||||
visited_combinations.insert(DMatrix::from(&matrix));
|
|
||||||
}
|
|
||||||
|
|
||||||
assert_eq!(visited_combinations.len(), all_combinations.len());
|
assert_eq!(visited_combinations.len(), all_combinations.len());
|
||||||
assert_eq!(visited_combinations, all_combinations, "Did not sample all possible values.");
|
assert_eq!(visited_combinations, all_combinations, "Did not sample all possible values.");
|
||||||
|
@ -108,21 +105,17 @@ mod slow {
|
||||||
let values = -1..=1;
|
let values = -1..=1;
|
||||||
let rows = 0..=2;
|
let rows = 0..=2;
|
||||||
let cols = 0..=3;
|
let cols = 0..=3;
|
||||||
let strategy = coo_with_duplicates(values.clone(), rows.clone(), cols.clone(), 2 * 3, 2);
|
let max_nnz = rows.end() * cols.end();
|
||||||
|
let strategy = coo_with_duplicates(values.clone(), rows.clone(), cols.clone(), max_nnz, 2);
|
||||||
|
|
||||||
// Enumerate all possible combinations that fit the constraints
|
// Enumerate all possible combinations that fit the constraints
|
||||||
// (note: this is only a subset of the matrices that can be generated by
|
// (note: this is only a subset of the matrices that can be generated by
|
||||||
// `coo_with_duplicates`)
|
// `coo_with_duplicates`)
|
||||||
let all_combinations = generate_all_possible_matrices(values, rows, cols);
|
let all_combinations = generate_all_possible_matrices(values, rows, cols);
|
||||||
|
|
||||||
let mut visited_combinations = HashSet::new();
|
let visited_combinations = sample_matrix_output_space(strategy,
|
||||||
for _ in 0..num_generated_matrices {
|
&mut runner,
|
||||||
let tree = strategy
|
num_generated_matrices);
|
||||||
.new_tree(&mut runner)
|
|
||||||
.expect("Tree generation should not fail");
|
|
||||||
let matrix = tree.current();
|
|
||||||
visited_combinations.insert(DMatrix::from(&matrix));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Here we cannot verify that the set of visited combinations is *equal* to
|
// Here we cannot verify that the set of visited combinations is *equal* to
|
||||||
// all possible outcomes with the given constraints, however the
|
// all possible outcomes with the given constraints, however the
|
||||||
|
@ -131,6 +124,110 @@ mod slow {
|
||||||
// is contained in the set of visited matrices
|
// is contained in the set of visited matrices
|
||||||
assert!(all_combinations.is_subset(&visited_combinations));
|
assert!(all_combinations.is_subset(&visited_combinations));
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: Tests for csr, csc and sparsity_pattern strategies
|
#[cfg(feature = "slow-tests")]
|
||||||
|
#[test]
|
||||||
|
fn csr_samples_all_admissible_outputs() {
|
||||||
|
// We use a deterministic test runner to make the test "stable".
|
||||||
|
let mut runner = TestRunner::deterministic();
|
||||||
|
|
||||||
|
// This number needs to be high enough that, with high probability, we sample
|
||||||
|
// all possible cases
|
||||||
|
let num_generated_matrices = 500000;
|
||||||
|
|
||||||
|
let values = -1..=1;
|
||||||
|
let rows = 0..=2;
|
||||||
|
let cols = 0..=3;
|
||||||
|
let max_nnz = rows.end() * cols.end();
|
||||||
|
let strategy = csr(values.clone(), rows.clone(), cols.clone(), max_nnz);
|
||||||
|
|
||||||
|
let all_combinations = generate_all_possible_matrices(values, rows, cols);
|
||||||
|
|
||||||
|
let visited_combinations = sample_matrix_output_space(strategy,
|
||||||
|
&mut runner,
|
||||||
|
num_generated_matrices);
|
||||||
|
|
||||||
|
assert_eq!(visited_combinations.len(), all_combinations.len());
|
||||||
|
assert_eq!(visited_combinations, all_combinations, "Did not sample all possible values.");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "slow-tests")]
|
||||||
|
#[test]
|
||||||
|
fn csc_samples_all_admissible_outputs() {
|
||||||
|
// We use a deterministic test runner to make the test "stable".
|
||||||
|
let mut runner = TestRunner::deterministic();
|
||||||
|
|
||||||
|
// This number needs to be high enough that, with high probability, we sample
|
||||||
|
// all possible cases
|
||||||
|
let num_generated_matrices = 500000;
|
||||||
|
|
||||||
|
let values = -1..=1;
|
||||||
|
let rows = 0..=2;
|
||||||
|
let cols = 0..=3;
|
||||||
|
let max_nnz = rows.end() * cols.end();
|
||||||
|
let strategy = csc(values.clone(), rows.clone(), cols.clone(), max_nnz);
|
||||||
|
|
||||||
|
let all_combinations = generate_all_possible_matrices(values, rows, cols);
|
||||||
|
|
||||||
|
let visited_combinations = sample_matrix_output_space(strategy,
|
||||||
|
&mut runner,
|
||||||
|
num_generated_matrices);
|
||||||
|
|
||||||
|
assert_eq!(visited_combinations.len(), all_combinations.len());
|
||||||
|
assert_eq!(visited_combinations, all_combinations, "Did not sample all possible values.");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "slow-tests")]
|
||||||
|
#[test]
|
||||||
|
fn sparsity_pattern_samples_all_admissible_outputs() {
|
||||||
|
let mut runner = TestRunner::deterministic();
|
||||||
|
|
||||||
|
let num_generated_patterns = 50000;
|
||||||
|
|
||||||
|
let major_dims = 0..=2;
|
||||||
|
let minor_dims = 0..=3;
|
||||||
|
let max_nnz = major_dims.end() * minor_dims.end();
|
||||||
|
let strategy = sparsity_pattern(major_dims.clone(), minor_dims.clone(), max_nnz);
|
||||||
|
|
||||||
|
let visited_patterns: HashSet<_> = sample_strategy(strategy, &mut runner)
|
||||||
|
.take(num_generated_patterns)
|
||||||
|
.map(|pattern| {
|
||||||
|
// We represent patterns as dense matrices with 1 if an entry is occupied,
|
||||||
|
// 0 otherwise
|
||||||
|
let values = vec![1; pattern.nnz()];
|
||||||
|
CsrMatrix::try_from_pattern_and_values(pattern, values).unwrap()
|
||||||
|
})
|
||||||
|
.map(|csr| DMatrix::from(&csr))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let all_possible_patterns = generate_all_possible_matrices(0..=1, major_dims, minor_dims);
|
||||||
|
|
||||||
|
assert_eq!(visited_patterns.len(), all_possible_patterns.len());
|
||||||
|
assert_eq!(visited_patterns, all_possible_patterns);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn sample_matrix_output_space<S>(strategy: S,
|
||||||
|
runner: &mut TestRunner,
|
||||||
|
num_samples: usize)
|
||||||
|
-> HashSet<DMatrix<i32>>
|
||||||
|
where
|
||||||
|
S: Strategy,
|
||||||
|
DMatrix<i32>: for<'b> From<&'b S::Value>
|
||||||
|
{
|
||||||
|
sample_strategy(strategy, runner)
|
||||||
|
.take(num_samples)
|
||||||
|
.map(|matrix| DMatrix::from(&matrix))
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn sample_strategy<'a, S: 'a + Strategy>(strategy: S, runner: &'a mut TestRunner)
|
||||||
|
-> impl 'a + Iterator<Item=S::Value> {
|
||||||
|
repeat(()).map(move |_| {
|
||||||
|
let tree = strategy
|
||||||
|
.new_tree(runner)
|
||||||
|
.expect("Tree generation should not fail");
|
||||||
|
let value = tree.current();
|
||||||
|
value
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue