diff --git a/src/distribution/empirical.rs b/src/distribution/empirical.rs index 21832b3c..c1d8bc71 100644 --- a/src/distribution/empirical.rs +++ b/src/distribution/empirical.rs @@ -1,22 +1,45 @@ use crate::distribution::ContinuousCDF; use crate::statistics::*; -use core::cmp::Ordering; -use std::collections::BTreeMap; +use non_nan::NonNan; +use std::collections::btree_map::{BTreeMap, Entry}; +use std::convert::Infallible; +use std::ops::Bound; -#[derive(Clone, PartialEq, Debug)] -struct NonNan(T); +mod non_nan { + use core::cmp::Ordering; -impl Eq for NonNan {} + #[derive(Clone, Copy, PartialEq, Debug)] + pub struct NonNan(T); -impl PartialOrd for NonNan { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) + impl NonNan { + pub fn get(self) -> T { + self.0 + } } -} -impl Ord for NonNan { - fn cmp(&self, other: &Self) -> Ordering { - self.0.partial_cmp(&other.0).unwrap() + impl NonNan { + #[inline] + pub fn new(x: f64) -> Option { + if x.is_nan() { + None + } else { + Some(Self(x)) + } + } + } + + impl Eq for NonNan {} + + impl PartialOrd for NonNan { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } + } + + impl Ord for NonNan { + fn cmp(&self, other: &Self) -> Ordering { + self.0.partial_cmp(&other.0).unwrap() + } } } @@ -36,10 +59,15 @@ impl Ord for NonNan { /// ``` #[derive(Clone, PartialEq, Debug)] pub struct Empirical { - sum: f64, - mean_and_var: Option<(f64, f64)>, // keys are data points, values are number of data points with equal value data: BTreeMap, u64>, + + // The following fields are only logically valid if !data.is_empty(): + /// Total amount of data points (== sum of all _values_ inside self.data). + /// Must be 0 iff data.is_empty() + sum: u64, + mean: f64, + var: f64, } impl Empirical { @@ -56,54 +84,62 @@ impl Empirical { /// let mut result = Empirical::new(); /// assert!(result.is_ok()); /// ``` - #[allow(clippy::result_unit_err)] - pub fn new() -> Result { + pub fn new() -> Result { Ok(Empirical { - sum: 0., - mean_and_var: None, data: BTreeMap::new(), + sum: 0, + mean: 0.0, + var: 0.0, }) } pub fn add(&mut self, data_point: f64) { - if !data_point.is_nan() { - self.sum += 1.; - match self.mean_and_var { - Some((mean, var)) => { - let sum = self.sum; - let var = var + (sum - 1.) * (data_point - mean) * (data_point - mean) / sum; - let mean = mean + (data_point - mean) / sum; - self.mean_and_var = Some((mean, var)); - } - None => { - self.mean_and_var = Some((data_point, 0.)); - } - } - *self.data.entry(NonNan(data_point)).or_insert(0) += 1; - } + let map_key = match NonNan::new(data_point) { + Some(valid) => valid, + None => return, + }; + + self.sum += 1; + let sum = self.sum as f64; + self.var += (sum - 1.) * (data_point - self.mean) * (data_point - self.mean) / sum; + self.mean += (data_point - self.mean) / sum; + + self.data + .entry(map_key) + .and_modify(|c| *c += 1) + .or_insert(1); } pub fn remove(&mut self, data_point: f64) { - if !data_point.is_nan() { - if let (Some(val), Some((mean, var))) = - (self.data.remove(&NonNan(data_point)), self.mean_and_var) - { - if val == 1 && self.data.is_empty() { - self.mean_and_var = None; - self.sum = 0.; - return; - }; - // reset mean and var - let mean = (self.sum * mean - data_point) / (self.sum - 1.); - let var = - var - (self.sum - 1.) * (data_point - mean) * (data_point - mean) / self.sum; - self.sum -= 1.; - if val != 1 { - self.data.insert(NonNan(data_point), val - 1); - }; - self.mean_and_var = Some((mean, var)); + let map_key = match NonNan::new(data_point) { + Some(valid) => valid, + None => return, + }; + + let mut entry = match self.data.entry(map_key) { + Entry::Occupied(entry) => entry, + Entry::Vacant(_) => return, // no entry found + }; + + if *entry.get() == 1 { + entry.remove(); + if self.data.is_empty() { + // logically, this should not need special handling. + // FP math can result in mean or var being != 0.0 though. + self.sum = 0; + self.mean = 0.0; + self.var = 0.0; + return; } + } else { + *entry.get_mut() -= 1; } + + // reset mean and var + let sum = self.sum as f64; + self.mean = (sum * self.mean - data_point) / (sum - 1.); + self.var -= (sum - 1.) * (data_point - self.mean) * (data_point - self.mean) / sum; + self.sum -= 1; } // Due to issues with rounding and floating-point accuracy the default @@ -148,7 +184,7 @@ impl std::fmt::Display for Empirical { let mut enumerated_values = self .data .iter() - .flat_map(|(&NonNan(x), &count)| std::iter::repeat(x).take(count as usize)); + .flat_map(|(x, &count)| std::iter::repeat(x.get()).take(count as usize)); if let Some(x) = enumerated_values.next() { write!(f, "Empirical([{x:.3e}")?; @@ -190,48 +226,50 @@ impl ::rand::distributions::Distribution for Empirical { /// Panics if number of samples is zero impl Max for Empirical { fn max(&self) -> f64 { - self.data.keys().rev().map(|key| key.0).next().unwrap() + self.data.keys().rev().map(|key| key.get()).next().unwrap() } } /// Panics if number of samples is zero impl Min for Empirical { fn min(&self) -> f64 { - self.data.keys().map(|key| key.0).next().unwrap() + self.data.keys().map(|key| key.get()).next().unwrap() } } impl Distribution for Empirical { fn mean(&self) -> Option { - self.mean_and_var.map(|(mean, _)| mean) + if self.data.is_empty() { + None + } else { + Some(self.mean) + } } fn variance(&self) -> Option { - self.mean_and_var.map(|(_, var)| var / (self.sum - 1.)) + if self.data.is_empty() { + None + } else { + Some(self.var / (self.sum as f64 - 1.)) + } } } impl ContinuousCDF for Empirical { fn cdf(&self, x: f64) -> f64 { - let mut sum = 0; - for (keys, values) in &self.data { - if keys.0 > x { - return sum as f64 / self.sum; - } - sum += values; - } - sum as f64 / self.sum + let start = Bound::Unbounded; + let end = Bound::Included(NonNan::new(x).expect("x must not be NaN")); + + let sum: u64 = self.data.range((start, end)).map(|(_, v)| v).sum(); + sum as f64 / self.sum as f64 } fn sf(&self, x: f64) -> f64 { - let mut sum = 0; - for (keys, values) in self.data.iter().rev() { - if keys.0 <= x { - return sum as f64 / self.sum; - } - sum += values; - } - sum as f64 / self.sum + let start = Bound::Excluded(NonNan::new(x).expect("x must not be NaN")); + let end = Bound::Unbounded; + + let sum: u64 = self.data.range((start, end)).map(|(_, v)| v).sum(); + sum as f64 / self.sum as f64 } fn inverse_cdf(&self, p: f64) -> f64 { @@ -242,6 +280,78 @@ impl ContinuousCDF for Empirical { #[cfg(test)] mod tests { use super::*; + + #[test] + fn test_add_nan() { + let mut empirical = Empirical::new().unwrap(); + + // should not panic + empirical.add(f64::NAN); + } + + #[test] + fn test_remove_nan() { + let mut empirical = Empirical::new().unwrap(); + + empirical.add(5.2); + // should not panic + empirical.remove(f64::NAN); + } + + #[test] + fn test_remove_nonexisting() { + let mut empirical = Empirical::new().unwrap(); + + empirical.add(5.2); + // should not panic + empirical.remove(10.0); + } + + #[test] + fn test_remove_all() { + let mut empirical = Empirical::new().unwrap(); + + empirical.add(17.123); + empirical.add(-10.0); + empirical.add(0.0); + empirical.remove(-10.0); + empirical.remove(17.123); + empirical.remove(0.0); + + assert!(empirical.mean().is_none()); + assert!(empirical.variance().is_none()); + } + + #[test] + fn test_mean() { + fn test_mean_for_samples(expected_mean: f64, samples: Vec) { + let dist = Empirical::from_iter(samples); + assert_relative_eq!(dist.mean().unwrap(), expected_mean); + } + + let dist = Empirical::from_iter(vec![]); + assert!(dist.mean().is_none()); + + test_mean_for_samples(4.0, vec![4.0; 100]); + test_mean_for_samples(-0.2, vec![-0.2; 100]); + test_mean_for_samples(28.5, vec![21.3, 38.4, 12.7, 41.6]); + } + + #[test] + fn test_var() { + fn test_var_for_samples(expected_var: f64, samples: Vec) { + let dist = Empirical::from_iter(samples); + assert_relative_eq!(dist.variance().unwrap(), expected_var); + } + + let dist = Empirical::from_iter(vec![]); + assert!(dist.variance().is_none()); + + test_var_for_samples(0.0, vec![4.0; 100]); + test_var_for_samples(0.0, vec![-0.2; 100]); + test_var_for_samples(190.36666666666667, vec![21.3, 38.4, 12.7, 41.6]); + } + #[test] fn test_cdf() { let samples = vec![5.0, 10.0];