From b4e8668eddfdcb922ffe9388e84ea3293fd30552 Mon Sep 17 00:00:00 2001 From: Sasha Syrotenko Date: Wed, 8 Jan 2025 18:44:35 +0200 Subject: [PATCH 1/2] StatisticsV2: initial definition and validation method implementation --- datafusion/expr-common/Cargo.toml | 1 + datafusion/expr-common/src/lib.rs | 4 +- datafusion/expr-common/src/stats.rs | 265 ++++++++++++++++++++++++++++ 3 files changed, 269 insertions(+), 1 deletion(-) create mode 100644 datafusion/expr-common/src/stats.rs diff --git a/datafusion/expr-common/Cargo.toml b/datafusion/expr-common/Cargo.toml index 1ccc6fc17293..6f5f3cb960e7 100644 --- a/datafusion/expr-common/Cargo.toml +++ b/datafusion/expr-common/Cargo.toml @@ -35,6 +35,7 @@ name = "datafusion_expr_common" path = "src/lib.rs" [features] +stats_v2 = [] [dependencies] arrow = { workspace = true } diff --git a/datafusion/expr-common/src/lib.rs b/datafusion/expr-common/src/lib.rs index 179dd75ace85..c4d70e3205c0 100644 --- a/datafusion/expr-common/src/lib.rs +++ b/datafusion/expr-common/src/lib.rs @@ -33,4 +33,6 @@ pub mod interval_arithmetic; pub mod operator; pub mod signature; pub mod sort_properties; -pub mod type_coercion; +#[cfg(feature = "stats_v2")] +pub mod stats; +pub mod type_coercion; \ No newline at end of file diff --git a/datafusion/expr-common/src/stats.rs b/datafusion/expr-common/src/stats.rs new file mode 100644 index 000000000000..baac2c23345f --- /dev/null +++ b/datafusion/expr-common/src/stats.rs @@ -0,0 +1,265 @@ +use crate::interval_arithmetic::Interval; +use crate::stats::StatisticsV2::{Exponential, Gaussian, Unknown}; +use datafusion_common::ScalarValue; + +/// New, enhanced `Statistics` definition, represents three core definitions +pub enum StatisticsV2 { + Uniform { + interval: Interval, + }, + Exponential { + rate: ScalarValue, + offset: ScalarValue, + }, + Gaussian { + mean: ScalarValue, + variance: ScalarValue, + }, + Unknown { + mean: Option, + median: Option, + std_dev: Option, // standard deviation + range: Interval, + }, +} + +impl StatisticsV2 { + //! Validates accumulated statistic for selected distribution methods: + //! - For [`Exponential`], `rate` must be positive; + //! - For [`Gaussian`], `variant` must be non-negative + //! - For [`Unknown`], + //! - if `mean`, `median` are defined, the `range` must contain their values + //! - if `std_dev` is defined, it must be non-negative + pub fn is_valid(&self) -> bool { + match &self { + Exponential { rate, .. } => { + if rate.is_null() { return false; } + let zero = &ScalarValue::new_zero(&rate.data_type()).unwrap(); + rate.gt(zero) + } + Gaussian { variance, .. } => { + if variance.is_null() { return false; } + let zero = &ScalarValue::new_zero(&variance.data_type()).unwrap(); + variance.ge(zero) + } + Unknown { + mean, + median, + std_dev, + range, + } => { + if let (Some(mn), Some(md)) = (mean, median) { + if mn.is_null() || md.is_null() { return false; } + range.contains_value(mn).unwrap() && range.contains_value(md).unwrap() + } else if let Some(dev) = std_dev { + if dev.is_null() { return false; } + dev.gt(&ScalarValue::new_zero(&dev.data_type()).unwrap()) + } else { + false + } + } + _ => true, + } + } +} + +// #[cfg(test)] +#[cfg(all(test, feature = "stats_v2"))] +mod tests { + use crate::interval_arithmetic::Interval; + use crate::stats::StatisticsV2; + use arrow::datatypes::DataType; + use datafusion_common::ScalarValue; + + #[test] + fn uniform_stats_test() { + let uniform_stats = vec![ + ( + StatisticsV2::Uniform { + interval: Interval::make_zero(&DataType::Int8).unwrap(), + }, + true, + ), + ( + StatisticsV2::Uniform { + interval: Interval::make(Some(1), Some(100)).unwrap(), + }, + true, + ), + ( + StatisticsV2::Uniform { + interval: Interval::make(Some(-100), Some(-1)).unwrap(), + }, + true, + ), + ]; + + for case in uniform_stats { + assert_eq!(case.0.is_valid(), case.1); + } + } + + #[test] + fn exponential_stats_test() { + let exp_stats = vec![ + ( + StatisticsV2::Exponential { + rate: ScalarValue::Null, + offset: ScalarValue::Null, + }, + false, + ), + ( + StatisticsV2::Exponential { + rate: ScalarValue::Float32(Some(0.)), + offset: ScalarValue::Float32(Some(1.)), + }, + false, + ), + ( + StatisticsV2::Exponential { + rate: ScalarValue::Float32(Some(100.)), + offset: ScalarValue::Float32(Some(1.)), + }, + true, + ), + ( + StatisticsV2::Exponential { + rate: ScalarValue::Float32(Some(-100.)), + offset: ScalarValue::Float32(Some(1.)), + }, + false, + ), + ]; + for case in exp_stats { + assert_eq!(case.0.is_valid(), case.1); + } + } + + #[test] + fn gaussian_stats_test() { + let gaussian_stats = vec![ + ( + StatisticsV2::Gaussian { + mean: ScalarValue::Null, + variance: ScalarValue::Null, + }, + false, + ), + ( + StatisticsV2::Gaussian { + mean: ScalarValue::Float32(Some(0.)), + variance: ScalarValue::Float32(Some(0.)), + }, + true, + ), + ( + StatisticsV2::Gaussian { + mean: ScalarValue::Float32(Some(0.)), + variance: ScalarValue::Float32(Some(0.5)), + }, + true, + ), + ( + StatisticsV2::Gaussian { + mean: ScalarValue::Float32(Some(0.)), + variance: ScalarValue::Float32(Some(-0.5)), + }, + false, + ), + ]; + for case in gaussian_stats { + assert_eq!(case.0.is_valid(), case.1); + } + } + + #[test] + fn unknown_stats_test() { + let unknown_stats = vec![ + ( + StatisticsV2::Unknown { + mean: None, + median: None, + std_dev: None, + range: Interval::make_zero(&DataType::Float32).unwrap(), + }, + false, + ), + ( + StatisticsV2::Unknown { + mean: Some(ScalarValue::Float32(Some(0.))), + median: None, + std_dev: None, + range: Interval::make_zero(&DataType::Float32).unwrap(), + }, + false, + ), + ( + StatisticsV2::Unknown { + mean: None, + median: Some(ScalarValue::Float32(Some(0.))), + std_dev: None, + range: Interval::make_zero(&DataType::Float32).unwrap(), + }, + false, + ), + ( + StatisticsV2::Unknown { + mean: Some(ScalarValue::Null), + median: Some(ScalarValue::Null), + std_dev: Some(ScalarValue::Null), + range: Interval::make_zero(&DataType::Float32).unwrap(), + }, + false, + ), + ( + StatisticsV2::Unknown { + mean: Some(ScalarValue::Float32(Some(0.))), + median: Some(ScalarValue::Float32(Some(0.))), + std_dev: None, + range: Interval::make_zero(&DataType::Float32).unwrap(), + }, + true, + ), + ( + StatisticsV2::Unknown { + mean: Some(ScalarValue::Float64(Some(50.))), + median: Some(ScalarValue::Float64(Some(50.))), + std_dev: None, + range: Interval::make(Some(0.), Some(100.)).unwrap(), + }, + true, + ), + ( + StatisticsV2::Unknown { + mean: Some(ScalarValue::Float64(Some(50.))), + median: Some(ScalarValue::Float64(Some(50.))), + std_dev: None, + range: Interval::make(Some(-100.), Some(0.)).unwrap(), + }, + false, + ), + ( + StatisticsV2::Unknown { + mean: None, + median: None, + std_dev: Some(ScalarValue::Float64(Some(1.))), + range: Interval::make_zero(&DataType::Float64).unwrap() + }, + true, + ), + ( + StatisticsV2::Unknown { + mean: None, + median: None, + std_dev: Some(ScalarValue::Float64(Some(-1.))), + range: Interval::make_zero(&DataType::Float64).unwrap() + }, + false, + ), + ]; + for case in unknown_stats { + assert_eq!(case.0.is_valid(), case.1); + } + } +} From 992b3c0cecec7050d7d6c78e321c2bbf40ebb676 Mon Sep 17 00:00:00 2001 From: Sasha Syrotenko Date: Thu, 9 Jan 2025 17:46:21 +0200 Subject: [PATCH 2/2] Implement mean, median and standard deviation extraction for StatsV2 --- datafusion/expr-common/src/stats.rs | 361 +++++++++++++++++++++++++++- 1 file changed, 350 insertions(+), 11 deletions(-) diff --git a/datafusion/expr-common/src/stats.rs b/datafusion/expr-common/src/stats.rs index baac2c23345f..e35aeabbebd2 100644 --- a/datafusion/expr-common/src/stats.rs +++ b/datafusion/expr-common/src/stats.rs @@ -1,5 +1,6 @@ use crate::interval_arithmetic::Interval; -use crate::stats::StatisticsV2::{Exponential, Gaussian, Unknown}; +use crate::stats::StatisticsV2::{Exponential, Gaussian, Uniform, Unknown}; +use arrow::datatypes::DataType; use datafusion_common::ScalarValue; /// New, enhanced `Statistics` definition, represents three core definitions @@ -7,6 +8,7 @@ pub enum StatisticsV2 { Uniform { interval: Interval, }, + /// f(x, λ) = (λe)^-λx, if x >= 0 Exponential { rate: ScalarValue, offset: ScalarValue, @@ -33,12 +35,16 @@ impl StatisticsV2 { pub fn is_valid(&self) -> bool { match &self { Exponential { rate, .. } => { - if rate.is_null() { return false; } + if rate.is_null() { + return false; + } let zero = &ScalarValue::new_zero(&rate.data_type()).unwrap(); rate.gt(zero) } Gaussian { variance, .. } => { - if variance.is_null() { return false; } + if variance.is_null() { + return false; + } let zero = &ScalarValue::new_zero(&variance.data_type()).unwrap(); variance.ge(zero) } @@ -49,10 +55,14 @@ impl StatisticsV2 { range, } => { if let (Some(mn), Some(md)) = (mean, median) { - if mn.is_null() || md.is_null() { return false; } + if mn.is_null() || md.is_null() { + return false; + } range.contains_value(mn).unwrap() && range.contains_value(md).unwrap() } else if let Some(dev) = std_dev { - if dev.is_null() { return false; } + if dev.is_null() { + return false; + } dev.gt(&ScalarValue::new_zero(&dev.data_type()).unwrap()) } else { false @@ -61,6 +71,100 @@ impl StatisticsV2 { _ => true, } } + + /// Extract the mean value of given statistic, available for given statistic kinds: + /// - [`Uniform`]'s interval implicitly contains mean value, and it is calculable + /// by addition of upper and lower bound and dividing the result by 2. + /// - [`Exponential`] distribution mean is calculable by formula: 1/λ. λ must be non-negative. + /// - [`Gaussian`] distribution has it explicitly + /// - [`Unknown`] distribution _may_ have it explicitly + pub fn mean(&self) -> Option { + if !self.is_valid() { + return None; + } + match &self { + Uniform { interval, .. } => { + let aggregate = interval.lower().add_checked(interval.upper()); + if aggregate.is_err() { + // TODO: logs + return None; + } + + let float_aggregate = aggregate.unwrap().cast_to(&DataType::Float64); + if float_aggregate.is_err() { + // TODO: logs + return None; + } + + if let Ok(mean) = + float_aggregate.unwrap().div(ScalarValue::Float64(Some(2.))) + { + return Some(mean); + } + None + } + Exponential { rate, .. } => { + let one = &ScalarValue::new_one(&rate.data_type()).unwrap(); + if let Ok(mean) = one.div(rate) { + return Some(mean); + } + None + } + Gaussian { mean, .. } => Some(mean.clone()), + Unknown { mean, .. } => mean.clone(), + _ => unreachable!(), + } + } + + /// Extract the median value of given statistic, available for given statistic kinds: + /// - [`Exponential`] distribution median is calculable by formula: ln2/λ. λ must be non-negative. + /// - [`Gaussian`] distribution median is equals to mean, which is present explicitly. + /// - [`Unknown`] distribution median _may_ be present explicitly. + pub fn median(&self) -> Option { + if !self.is_valid() { + return None; + } + match &self { + Exponential { rate, .. } => { + let ln_two = ScalarValue::Float64(Some(2_f64.ln())); + if let Ok(median) = ln_two.div(rate) { + return Some(median); + } + None + } + Gaussian { mean, .. } => Some(mean.clone()), + Unknown { median, .. } => median.clone(), + _ => None, + } + } + + /// Extract the standard deviation of given statistic distribution: + /// - [`Exponential`]'s standard deviation is equal to mean value and calculable as 1/λ + /// - [`Gaussian`]'s standard deviation is a square root of variance. + /// - [`Unknown`]'s distribution standard deviation _may_ be present explicitly. + pub fn std_dev(&self) -> Option { + if !self.is_valid() { + return None; + } + match &self { + Exponential { rate, .. } => { + let one = &ScalarValue::new_one(&rate.data_type()).unwrap(); + if let Ok(std_dev) = one.div(rate) { + return Some(std_dev); + } + None + } + Gaussian { + variance: _variance, + .. + } => { + // TODO: sqrt() is not yet implemented for ScalarValue + None + } + Unknown { std_dev, .. } => std_dev.clone(), + _ => None, + } + } } // #[cfg(test)] @@ -71,8 +175,10 @@ mod tests { use arrow::datatypes::DataType; use datafusion_common::ScalarValue; + //region is_valid tests + // The test data in the following tests are placed as follows : (stat -> expected answer) #[test] - fn uniform_stats_test() { + fn uniform_stats_is_valid_test() { let uniform_stats = vec![ ( StatisticsV2::Uniform { @@ -100,7 +206,7 @@ mod tests { } #[test] - fn exponential_stats_test() { + fn exponential_stats_is_valid_test() { let exp_stats = vec![ ( StatisticsV2::Exponential { @@ -137,7 +243,7 @@ mod tests { } #[test] - fn gaussian_stats_test() { + fn gaussian_stats_is_valid_test() { let gaussian_stats = vec![ ( StatisticsV2::Gaussian { @@ -174,7 +280,7 @@ mod tests { } #[test] - fn unknown_stats_test() { + fn unknown_stats_is_valid_test() { let unknown_stats = vec![ ( StatisticsV2::Unknown { @@ -244,7 +350,7 @@ mod tests { mean: None, median: None, std_dev: Some(ScalarValue::Float64(Some(1.))), - range: Interval::make_zero(&DataType::Float64).unwrap() + range: Interval::make_zero(&DataType::Float64).unwrap(), }, true, ), @@ -253,7 +359,7 @@ mod tests { mean: None, median: None, std_dev: Some(ScalarValue::Float64(Some(-1.))), - range: Interval::make_zero(&DataType::Float64).unwrap() + range: Interval::make_zero(&DataType::Float64).unwrap(), }, false, ), @@ -262,4 +368,237 @@ mod tests { assert_eq!(case.0.is_valid(), case.1); } } + //endregion + + #[test] + fn mean_extraction_test() { + // The test data is placed as follows : (stat -> expected answer) + //region uniform + let mut stats = vec![ + ( + StatisticsV2::Uniform { + interval: Interval::make_zero(&DataType::Int64).unwrap(), + }, + Some(ScalarValue::Float64(Some(0.))), + ), + ( + StatisticsV2::Uniform { + interval: Interval::make_zero(&DataType::Float64).unwrap(), + }, + Some(ScalarValue::Float64(Some(0.))), + ), + ( + StatisticsV2::Uniform { + interval: Interval::make(Some(1), Some(100)).unwrap(), + }, + Some(ScalarValue::Float64(Some(50.5))), + ), + ( + StatisticsV2::Uniform { + interval: Interval::make(Some(-100), Some(-1)).unwrap(), + }, + Some(ScalarValue::Float64(Some(-50.5))), + ), + ( + StatisticsV2::Uniform { + interval: Interval::make(Some(-100), Some(100)).unwrap(), + }, + Some(ScalarValue::Float64(Some(0.))), + ), + ]; + //endregion + + //region exponential + stats.push(( + StatisticsV2::Exponential { + rate: ScalarValue::Float64(Some(2.)), + offset: ScalarValue::Float64(Some(0.)), + }, + Some(ScalarValue::Float64(Some(0.5))), + )); + //endregion + + // region gaussian + stats.push(( + StatisticsV2::Gaussian { + mean: ScalarValue::Float64(Some(0.)), + variance: ScalarValue::Float64(Some(1.)), + }, + Some(ScalarValue::Float64(Some(0.))), + )); + stats.push(( + StatisticsV2::Gaussian { + mean: ScalarValue::Float64(Some(-2.)), + variance: ScalarValue::Float64(Some(0.5)), + }, + Some(ScalarValue::Float64(Some(-2.))), + )); + //endregion + + //region unknown + stats.push(( + StatisticsV2::Unknown { + mean: None, + median: None, + std_dev: None, + range: Interval::make_zero(&DataType::Int8).unwrap(), + }, + None, + )); + stats.push(( + // Median is None, the statistic is not valid, correct answer is None. + StatisticsV2::Unknown { + mean: Some(ScalarValue::Float64(Some(42.))), + median: None, + std_dev: None, + range: Interval::make_zero(&DataType::Float64).unwrap(), + }, + None, + )); + stats.push(( + // Range doesn't include mean and/or median, so - not valid + StatisticsV2::Unknown { + mean: Some(ScalarValue::Float64(Some(42.))), + median: Some(ScalarValue::Float64(Some(42.))), + std_dev: None, + range: Interval::make_zero(&DataType::Float64).unwrap(), + }, + None, + )); + stats.push(( + StatisticsV2::Unknown { + mean: Some(ScalarValue::Float64(Some(42.))), + median: Some(ScalarValue::Float64(Some(42.))), + std_dev: None, + range: Interval::make(Some(25.), Some(50.)).unwrap(), + }, + Some(ScalarValue::Float64(Some(42.))), + )); + //endregion + + for case in stats { + assert_eq!(case.0.mean(), case.1); + } + } + + #[test] + fn median_extraction_test() { + // The test data is placed as follows : (stat -> expected answer) + //region uniform + let stats = vec![ + ( + StatisticsV2::Uniform { + interval: Interval::make_zero(&DataType::Int64).unwrap(), + }, + None, + ), + ( + StatisticsV2::Exponential { + rate: ScalarValue::Float64(Some(2_f64.ln())), + offset: ScalarValue::Float64(Some(0.)), + }, + Some(ScalarValue::Float64(Some(1.))), + ), + ( + StatisticsV2::Gaussian { + mean: ScalarValue::Float64(Some(2.)), + variance: ScalarValue::Float64(Some(1.)), + }, + Some(ScalarValue::Float64(Some(2.))), + ), + ( + StatisticsV2::Unknown { + mean: None, + median: None, + std_dev: None, + range: Interval::make_zero(&DataType::Int8).unwrap(), + }, + None, + ), + ( + // Mean is None, statistics is not valid + StatisticsV2::Unknown { + mean: None, + median: Some(ScalarValue::Float64(Some(12.))), + std_dev: None, + range: Interval::make_zero(&DataType::Float64).unwrap(), + }, + None, + ), + ( + // Range doesn't include mean and/or median, so - not valid + StatisticsV2::Unknown { + mean: None, + median: Some(ScalarValue::Float64(Some(12.))), + std_dev: None, + range: Interval::make_zero(&DataType::Float64).unwrap(), + }, + None, + ), + ( + StatisticsV2::Unknown { + mean: Some(ScalarValue::Float64(Some(12.))), + median: Some(ScalarValue::Float64(Some(12.))), + std_dev: None, + range: Interval::make(Some(0.), Some(25.)).unwrap(), + }, + Some(ScalarValue::Float64(Some(12.))), + ), + ]; + + for case in stats { + assert_eq!(case.0.median(), case.1); + } + } + + #[test] + fn std_dev_extraction_test() { + // The test data is placed as follows : (stat -> expected answer) + let stats = vec![ + ( + StatisticsV2::Uniform { + interval: Interval::make_zero(&DataType::Int64).unwrap(), + }, + None, + ), + ( + StatisticsV2::Exponential { + rate: ScalarValue::Float64(Some(10.)), + offset: ScalarValue::Float64(Some(0.)), + }, + Some(ScalarValue::Float64(Some(0.1))), + ), + ( + StatisticsV2::Gaussian { + mean: ScalarValue::Float64(Some(0.)), + variance: ScalarValue::Float64(Some(1.)), + }, + None, + ), + ( + StatisticsV2::Unknown { + mean: None, + median: None, + std_dev: None, + range: Interval::make_zero(&DataType::Int8).unwrap(), + }, + None, + ), + ( + StatisticsV2::Unknown { + mean: None, + median: None, + std_dev: Some(ScalarValue::Float64(Some(0.02))), + range: Interval::make_zero(&DataType::Float64).unwrap(), + }, + Some(ScalarValue::Float64(Some(0.02))), + ), + ]; + + //endregion + + for case in stats { + assert_eq!(case.0.std_dev(), case.1); + } + } }