From de7bd71ccf01f0ff428d690f6e5b5297f06c9ca4 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Mon, 29 Apr 2024 09:25:42 -0700 Subject: [PATCH] fix --- core/src/execution/datafusion/expressions/stddev.rs | 9 ++++++++- docs/source/user-guide/expressions.md | 2 ++ .../scala/org/apache/comet/serde/QueryPlanSerde.scala | 2 ++ 3 files changed, 12 insertions(+), 1 deletion(-) diff --git a/core/src/execution/datafusion/expressions/stddev.rs b/core/src/execution/datafusion/expressions/stddev.rs index cc0c845833..efb2449d25 100644 --- a/core/src/execution/datafusion/expressions/stddev.rs +++ b/core/src/execution/datafusion/expressions/stddev.rs @@ -31,6 +31,10 @@ use crate::execution::datafusion::expressions::utils::down_cast_any_ref; use crate::execution::datafusion::expressions::variance::VarianceAccumulator; /// STDDEV and STDDEV_SAMP (standard deviation) aggregate expression +/// The implementation is mostly the same as DataFusion's implementation. The reason +/// we have our own implementation is that DataFusion uses UInt64 for the state field `count`, +/// while Spark uses Double for count. We have also added `null_on_divide_by_zero` +/// to be consistent with Spark's implementation. 
#[derive(Debug)] pub struct Stddev { name: String, @@ -112,7 +116,10 @@ impl PartialEq for Stddev { fn eq(&self, other: &dyn Any) -> bool { down_cast_any_ref(other) .downcast_ref::() - .map(|x| self.name == x.name && self.expr.eq(&x.expr)) + .map(|x| + self.name == x.name + && self.expr.eq(&x.expr) + && self.null_on_divide_by_zero == x.null_on_divide_by_zero) .unwrap_or(false) } } diff --git a/docs/source/user-guide/expressions.md b/docs/source/user-guide/expressions.md index f67a4eada0..38c86c7271 100644 --- a/docs/source/user-guide/expressions.md +++ b/docs/source/user-guide/expressions.md @@ -107,3 +107,5 @@ The following Spark expressions are currently available: - CovSample - VariancePop - VarianceSamp + - StddevPop + - StddevSamp diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 9ec14b1c0e..1e43795143 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -521,6 +521,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { .setStddev(stdBuilder) .build()) } else { + withInfo(aggExpr, child) None } case std @ StddevPop(child, nullOnDivideByZero) => @@ -540,6 +541,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { .setStddev(stdBuilder) .build()) } else { + withInfo(aggExpr, child) None } case fn =>