Skip to content

Commit

Permalink
combine COVAR_SAMP and COVAR_POP
Browse files Browse the repository at this point in the history
  • Loading branch information
Huaxin Gao committed Apr 10, 2024
1 parent 263c8af commit 1ffff8b
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 92 deletions.
100 changes: 10 additions & 90 deletions core/src/execution/datafusion/expressions/covariance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,24 +30,18 @@ use datafusion_common::{
};
use datafusion_physical_expr::{
aggregate::utils::down_cast_any_ref,
expressions::{format_state_name, StatsType},
expressions::format_state_name,
AggregateExpr, PhysicalExpr,
};
use crate::execution::datafusion::expressions::stats::StatsType;

/// COVAR and COVAR_SAMP aggregate expression
/// COVAR_SAMP and COVAR_POP aggregate expression
#[derive(Debug, Clone)]
pub struct Covariance {
name: String,
expr1: Arc<dyn PhysicalExpr>,
expr2: Arc<dyn PhysicalExpr>,
}

/// COVAR_POP aggregate expression
#[derive(Debug)]
pub struct CovariancePop {
name: String,
expr1: Arc<dyn PhysicalExpr>,
expr2: Arc<dyn PhysicalExpr>,
stats_type: StatsType,
}

impl Covariance {
Expand All @@ -57,13 +51,15 @@ impl Covariance {
expr2: Arc<dyn PhysicalExpr>,
name: impl Into<String>,
data_type: DataType,
stats_type: StatsType,
) -> Self {
// the result of covariance just support FLOAT64 data type.
assert!(matches!(data_type, DataType::Float64));
Self {
name: name.into(),
expr1,
expr2,
stats_type,
}
}
}
Expand All @@ -79,7 +75,7 @@ impl AggregateExpr for Covariance {
}

fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
Ok(Box::new(CovarianceAccumulator::try_new(StatsType::Sample)?))
Ok(Box::new(CovarianceAccumulator::try_new(self.stats_type)?))
}

fn state_fields(&self) -> Result<Vec<Field>> {
Expand Down Expand Up @@ -120,84 +116,8 @@ impl PartialEq<dyn Any> for Covariance {
fn eq(&self, other: &dyn Any) -> bool {
down_cast_any_ref(other)
.downcast_ref::<Self>()
.map(|x| self.name == x.name && self.expr1.eq(&x.expr1) && self.expr2.eq(&x.expr2))
.unwrap_or(false)
}
}

impl CovariancePop {
/// Create a new COVAR_POP aggregate function
pub fn new(
expr1: Arc<dyn PhysicalExpr>,
expr2: Arc<dyn PhysicalExpr>,
name: impl Into<String>,
data_type: DataType,
) -> Self {
// the result of covariance just support FLOAT64 data type.
assert!(matches!(data_type, DataType::Float64));
Self {
name: name.into(),
expr1,
expr2,
}
}
}

impl AggregateExpr for CovariancePop {
/// Return a reference to Any that can be used for downcasting
fn as_any(&self) -> &dyn Any {
self
}

fn field(&self) -> Result<Field> {
Ok(Field::new(&self.name, DataType::Float64, true))
}

fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
Ok(Box::new(CovarianceAccumulator::try_new(
StatsType::Population,
)?))
}

fn state_fields(&self) -> Result<Vec<Field>> {
Ok(vec![
Field::new(
format_state_name(&self.name, "count"),
DataType::Float64,
true,
),
Field::new(
format_state_name(&self.name, "mean1"),
DataType::Float64,
true,
),
Field::new(
format_state_name(&self.name, "mean2"),
DataType::Float64,
true,
),
Field::new(
format_state_name(&self.name, "algo_const"),
DataType::Float64,
true,
),
])
}

fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
vec![self.expr1.clone(), self.expr2.clone()]
}

fn name(&self) -> &str {
&self.name
}
}

impl PartialEq<dyn Any> for CovariancePop {
fn eq(&self, other: &dyn Any) -> bool {
down_cast_any_ref(other)
.downcast_ref::<Self>()
.map(|x| self.name == x.name && self.expr1.eq(&x.expr1) && self.expr2.eq(&x.expr2))
.map(|x| self.name == x.name && self.expr1.eq(&x.expr1) && self.expr2.eq(&x.expr2)
&& self.stats_type == x.stats_type)
.unwrap_or(false)
}
}
Expand Down Expand Up @@ -362,7 +282,7 @@ impl Accumulator for CovarianceAccumulator {

fn evaluate(&mut self) -> Result<ScalarValue> {
let count = match self.stats_type {
datafusion_physical_expr::expressions::StatsType::Population => self.count,
StatsType::Population => self.count,
StatsType::Sample => {
if self.count > 0.0 {
self.count - 1.0
Expand Down
1 change: 1 addition & 0 deletions core/src/execution/datafusion/expressions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ pub mod avg;
pub mod avg_decimal;
pub mod bloom_filter_might_contain;
pub mod covariance;
pub mod stats;
pub mod strings;
pub mod subquery;
pub mod sum_decimal;
Expand Down
27 changes: 27 additions & 0 deletions core/src/execution/datafusion/expressions/stats.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/// Selects whether a statistical aggregate is computed over a whole
/// population or over a sample drawn from one.
///
/// The value is a plain, field-less tag: it is `Copy`, comparable with `==`,
/// and hashable, so it can be stored inside expression structs, compared as
/// part of their equality checks, and used as a key in hash-based collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum StatsType {
    /// Treat the input rows as the entire population (e.g. `COVAR_POP`).
    Population,
    /// Treat the input rows as a sample of a larger population
    /// (e.g. `COVAR_SAMP`).
    Sample,
}
7 changes: 5 additions & 2 deletions core/src/execution/datafusion/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ use crate::{
bloom_filter_might_contain::BloomFilterMightContain,
cast::Cast,
checkoverflow::CheckOverflow,
covariance::{Covariance, CovariancePop},
covariance::Covariance,
if_expr::IfExpr,
scalar_funcs::create_comet_physical_fun,
strings::{Contains, EndsWith, Like, StartsWith, StringSpaceExec, SubstringExec},
Expand All @@ -86,6 +86,7 @@ use crate::{
spark_partitioning::{partitioning::PartitioningStruct, Partitioning as SparkPartitioning},
},
};
use crate::execution::datafusion::expressions::stats::StatsType;

// For clippy error on type_complexity.
type ExecResult<T> = Result<T, ExecutionError>;
Expand Down Expand Up @@ -1177,17 +1178,19 @@ impl PhysicalPlanner {
child2,
"covariance",
datatype,
StatsType::Sample
)))
}
AggExprStruct::CovPopulation(expr) => {
let child1 = self.create_expr(expr.child1.as_ref().unwrap(), schema.clone())?;
let child2 = self.create_expr(expr.child2.as_ref().unwrap(), schema.clone())?;
let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
Ok(Arc::new(CovariancePop::new(
Ok(Arc::new(Covariance::new(
child1,
child2,
"covariance_pop",
datatype,
StatsType::Population
)))
}
}
Expand Down

0 comments on commit 1ffff8b

Please sign in to comment.