From 75405b1e24ed71561b3fb9e210dfab3f4c3e0d39 Mon Sep 17 00:00:00 2001 From: vasil <vaskonikolov2003@gmail.com> Date: Tue, 25 Apr 2023 01:25:36 +0300 Subject: [PATCH] fix bug, add test in tests folder --- src/base/mod.rs | 2 -- src/base/statistics.rs | 12 ++++-------- tests/core/mod.rs | 1 + src/base/variance_test.rs => tests/core/variance.rs | 6 +++--- 4 files changed, 8 insertions(+), 13 deletions(-) rename src/base/variance_test.rs => tests/core/variance.rs (72%) diff --git a/src/base/mod.rs b/src/base/mod.rs index 1eabbfcf..0f09cc33 100644 --- a/src/base/mod.rs +++ b/src/base/mod.rs @@ -64,5 +64,3 @@ pub use self::matrix_view::*; pub use self::storage::*; #[cfg(any(feature = "std", feature = "alloc"))] pub use self::vec_storage::*; - -mod variance_test; diff --git a/src/base/statistics.rs b/src/base/statistics.rs index ebefb49d..6007f8c7 100644 --- a/src/base/statistics.rs +++ b/src/base/statistics.rs @@ -335,16 +335,12 @@ impl<T: Scalar, R: Dim, C: Dim, S: RawStorage<T, R, C>> Matrix<T, R, C, S> { if self.is_empty() { T::zero() } else { - // cannot use sum since `T` is not `Sum` by trait bounds - let sum_of_elements = self.iter().cloned().fold(T::zero(), |a, b| a + b); - let n_elements = crate::convert::<_, T>(self.len() as f64); - let mean = sum_of_elements / n_elements.clone(); + let n_elements: T = crate::convert(self.len() as f64); + let mean = self.mean(); - let variance = self.iter().cloned().fold(T::zero(), |acc, x| { + self.iter().cloned().fold(T::zero(), |acc, x| { acc + (x.clone() - mean.clone()) * (x.clone() - mean.clone()) - }) / n_elements; - - variance + }) / n_elements } } diff --git a/tests/core/mod.rs b/tests/core/mod.rs index 0f7ee85b..f0484e4d 100644 --- a/tests/core/mod.rs +++ b/tests/core/mod.rs @@ -11,6 +11,7 @@ mod reshape; #[cfg(feature = "rkyv-serialize-no-std")] mod rkyv; mod serde; +mod variance; #[cfg(feature = "compare")] mod matrixcompare; diff --git a/src/base/variance_test.rs b/tests/core/variance.rs similarity index 72% rename from src/base/variance_test.rs rename to tests/core/variance.rs index 4319e156..c643ea3f 100644 --- a/src/base/variance_test.rs +++ b/tests/core/variance.rs @@ -1,10 +1,10 @@ -#[cfg(test)] -use crate::DVector; +use nalgebra::DVector; #[test] fn test_variance_new() { let long_repeating_vector = DVector::repeat(10_000, 100000000.0); assert_eq!(long_repeating_vector.variance(), 0.0); let short_vec = DVector::from_vec(vec![1., 2., 3.]); - assert_eq!(short_vec.variance(), 2.0 / 3.0) + + assert_eq!(short_vec.variance(), 2.0 / 3.0); }