fix bug, add test in tests folder

This commit is contained in:
vasil 2023-04-25 01:25:36 +03:00
parent fc56abe481
commit 75405b1e24
4 changed files with 8 additions and 13 deletions

View File

@ -64,5 +64,3 @@ pub use self::matrix_view::*;
pub use self::storage::*; pub use self::storage::*;
#[cfg(any(feature = "std", feature = "alloc"))] #[cfg(any(feature = "std", feature = "alloc"))]
pub use self::vec_storage::*; pub use self::vec_storage::*;
mod variance_test;

View File

@ -335,16 +335,12 @@ impl<T: Scalar, R: Dim, C: Dim, S: RawStorage<T, R, C>> Matrix<T, R, C, S> {
if self.is_empty() { if self.is_empty() {
T::zero() T::zero()
} else { } else {
// cannot use sum since `T` is not `Sum` by trait bounds let n_elements: T = crate::convert(self.len() as f64);
let sum_of_elements = self.iter().cloned().fold(T::zero(), |a, b| a + b); let mean = self.mean();
let n_elements = crate::convert::<_, T>(self.len() as f64);
let mean = sum_of_elements / n_elements.clone();
let variance = self.iter().cloned().fold(T::zero(), |acc, x| { self.iter().cloned().fold(T::zero(), |acc, x| {
acc + (x.clone() - mean.clone()) * (x.clone() - mean.clone()) acc + (x.clone() - mean.clone()) * (x.clone() - mean.clone())
}) / n_elements; }) / n_elements
variance
} }
} }

View File

@ -11,6 +11,7 @@ mod reshape;
#[cfg(feature = "rkyv-serialize-no-std")] #[cfg(feature = "rkyv-serialize-no-std")]
mod rkyv; mod rkyv;
mod serde; mod serde;
mod variance;
#[cfg(feature = "compare")] #[cfg(feature = "compare")]
mod matrixcompare; mod matrixcompare;

View File

@ -1,10 +1,10 @@
#[cfg(test)] use nalgebra::DVector;
use crate::DVector;
#[test] #[test]
fn test_variance_new() { fn test_variance_new() {
let long_repeating_vector = DVector::repeat(10_000, 100000000.0); let long_repeating_vector = DVector::repeat(10_000, 100000000.0);
assert_eq!(long_repeating_vector.variance(), 0.0); assert_eq!(long_repeating_vector.variance(), 0.0);
let short_vec = DVector::from_vec(vec![1., 2., 3.]); let short_vec = DVector::from_vec(vec![1., 2., 3.]);
assert_eq!(short_vec.variance(), 2.0 / 3.0)
assert_eq!(short_vec.variance(), 2.0 / 3.0);
} }