diff --git a/CHANGELOG.md b/CHANGELOG.md index cf0253aa..971c5173 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,10 @@ documented here. This project adheres to [Semantic Versioning](https://semver.org/). +## Unreleased + +### Fixed +- Fixed severe catastrophic cancellation issue in variance calculation. ## [0.32.2] (07 March 2023) diff --git a/Cargo.toml b/Cargo.toml index 1d36aeb1..f5c67d15 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -111,6 +111,7 @@ serde_json = "1.0" rand_xorshift = "0.3" rand_isaac = "0.3" criterion = { version = "0.4", features = ["html_reports"] } +nalgebra = { path = ".", features = ["debug", "compare", "rand", "macros"]} # For matrix comparison macro matrixcompare = "0.3.0" diff --git a/src/base/statistics.rs b/src/base/statistics.rs index 9f0e0ee6..6007f8c7 100644 --- a/src/base/statistics.rs +++ b/src/base/statistics.rs @@ -335,12 +335,12 @@ impl> Matrix { if self.is_empty() { T::zero() } else { - let val = self.iter().cloned().fold((T::zero(), T::zero()), |a, b| { - (a.0 + b.clone() * b.clone(), a.1 + b) - }); - let denom = T::one() / crate::convert::<_, T>(self.len() as f64); - let vd = val.1 * denom.clone(); - val.0 * denom - vd.clone() * vd + let n_elements: T = crate::convert(self.len() as f64); + let mean = self.mean(); + + self.iter().cloned().fold(T::zero(), |acc, x| { + acc + (x.clone() - mean.clone()) * (x.clone() - mean.clone()) + }) / n_elements } } diff --git a/tests/core/mod.rs b/tests/core/mod.rs index 0f7ee85b..f0484e4d 100644 --- a/tests/core/mod.rs +++ b/tests/core/mod.rs @@ -11,6 +11,7 @@ mod reshape; #[cfg(feature = "rkyv-serialize-no-std")] mod rkyv; mod serde; +mod variance; #[cfg(feature = "compare")] mod matrixcompare; diff --git a/tests/core/variance.rs b/tests/core/variance.rs new file mode 100644 index 00000000..eb08ea0f --- /dev/null +++ b/tests/core/variance.rs @@ -0,0 +1,18 @@ +use nalgebra::DVector; + +#[test] +fn test_variance_catastrophic_cancellation() { + let long_repeating_vector = 
DVector::repeat(10_000, 100000000.0); + assert_eq!(long_repeating_vector.variance(), 0.0); + + let short_vec = DVector::from_vec(vec![1., 2., 3.]); + assert_eq!(short_vec.variance(), 2.0 / 3.0); + + let short_vec = + DVector::<f64>::from_vec(vec![1.0e8 + 4.0, 1.0e8 + 7.0, 1.0e8 + 13.0, 1.0e8 + 16.0]); + assert_eq!(short_vec.variance(), 22.5); + + let short_vec = + DVector::<f64>::from_vec(vec![1.0e9 + 4.0, 1.0e9 + 7.0, 1.0e9 + 13.0, 1.0e9 + 16.0]); + assert_eq!(short_vec.variance(), 22.5); +}