From 98ab12b052e5207f47106e3f37333863f6567b86 Mon Sep 17 00:00:00 2001 From: ulysses-you Date: Thu, 29 Feb 2024 17:50:56 +0800 Subject: [PATCH] Support collect_set --- .../backendsapi/velox/VeloxBackend.scala | 2 + .../VeloxAggregateFunctionsSuite.scala | 40 ++++++++++ cpp/velox/substrait/SubstraitParser.cc | 10 +-- ep/build-velox/src/get_velox.sh | 4 +- .../backendsapi/BackendSettingsApi.scala | 2 + .../extension/ColumnarOverrides.scala | 1 + .../extension/RewriteCollectSet.scala | 76 +++++++++++++++++++ 7 files changed, 124 insertions(+), 11 deletions(-) create mode 100644 gluten-core/src/main/scala/io/glutenproject/extension/RewriteCollectSet.scala diff --git a/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/VeloxBackend.scala b/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/VeloxBackend.scala index 939ea3d7617df..6977f96b57b70 100644 --- a/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/VeloxBackend.scala +++ b/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/VeloxBackend.scala @@ -449,4 +449,6 @@ object BackendSettings extends BackendSettingsApi { // vanilla Spark, we need to rewrite the aggregate to get the correct data type. 
true } + + override def shouldRewriteCollectSet(): Boolean = true } diff --git a/backends-velox/src/test/scala/io/glutenproject/execution/VeloxAggregateFunctionsSuite.scala b/backends-velox/src/test/scala/io/glutenproject/execution/VeloxAggregateFunctionsSuite.scala index 85421b138a590..8d2bb45e2683c 100644 --- a/backends-velox/src/test/scala/io/glutenproject/execution/VeloxAggregateFunctionsSuite.scala +++ b/backends-velox/src/test/scala/io/glutenproject/execution/VeloxAggregateFunctionsSuite.scala @@ -699,6 +699,46 @@ abstract class VeloxAggregateFunctionsSuite extends VeloxWholeStageTransformerSu } } + test("test collect_set") { + runQueryAndCompare("SELECT array_sort(collect_set(l_partkey)) FROM lineitem") { + df => + { + assert( + getExecutedPlan(df).count( + plan => { + plan.isInstanceOf[HashAggregateExecTransformer] + }) == 2) + } + } + + runQueryAndCompare( + """ + |SELECT array_sort(collect_set(l_suppkey)), array_sort(collect_set(l_partkey)) + |FROM lineitem + |""".stripMargin) { + df => + { + assert( + getExecutedPlan(df).count( + plan => { + plan.isInstanceOf[HashAggregateExecTransformer] + }) == 2) + } + } + + runQueryAndCompare( + "SELECT count(distinct l_suppkey), array_sort(collect_set(l_partkey)) FROM lineitem") { + df => + { + assert( + getExecutedPlan(df).count( + plan => { + plan.isInstanceOf[HashAggregateExecTransformer] + }) == 4) + } + } + } + test("count(1)") { runQueryAndCompare( """ diff --git a/cpp/velox/substrait/SubstraitParser.cc b/cpp/velox/substrait/SubstraitParser.cc index 8281e90f42dcc..c38c9aa65f420 100644 --- a/cpp/velox/substrait/SubstraitParser.cc +++ b/cpp/velox/substrait/SubstraitParser.cc @@ -368,20 +368,12 @@ std::unordered_map SubstraitParser::substraitVeloxFunc {"starts_with", "startswith"}, {"named_struct", "row_constructor"}, {"bit_or", "bitwise_or_agg"}, - {"bit_or_partial", "bitwise_or_agg_partial"}, - {"bit_or_merge", "bitwise_or_agg_merge"}, {"bit_and", "bitwise_and_agg"}, - {"bit_and_partial", 
"bitwise_and_agg_partial"}, - {"bit_and_merge", "bitwise_and_agg_merge"}, {"murmur3hash", "hash_with_seed"}, {"modulus", "remainder"}, {"date_format", "format_datetime"}, {"collect_set", "set_agg"}, - {"collect_set_partial", "set_agg_partial"}, - {"collect_set_merge", "set_agg_merge"}, - {"collect_list", "array_agg"}, - {"collect_list_partial", "array_agg_partial"}, - {"collect_list_merge", "array_agg_merge"}}; + {"collect_list", "array_agg"}}; const std::unordered_map SubstraitParser::typeMap_ = { {"bool", "BOOLEAN"}, diff --git a/ep/build-velox/src/get_velox.sh b/ep/build-velox/src/get_velox.sh index df92be5030b31..568987dbe8ad9 100755 --- a/ep/build-velox/src/get_velox.sh +++ b/ep/build-velox/src/get_velox.sh @@ -16,8 +16,8 @@ set -exu -VELOX_REPO=https://github.com/oap-project/velox.git -VELOX_BRANCH=2024_02_28 +VELOX_REPO=https://github.com/ulysses-you/velox.git +VELOX_BRANCH=setagg VELOX_HOME="" #Set on run gluten on HDFS diff --git a/gluten-core/src/main/scala/io/glutenproject/backendsapi/BackendSettingsApi.scala b/gluten-core/src/main/scala/io/glutenproject/backendsapi/BackendSettingsApi.scala index 74973a60c6761..20cedb9b6336c 100644 --- a/gluten-core/src/main/scala/io/glutenproject/backendsapi/BackendSettingsApi.scala +++ b/gluten-core/src/main/scala/io/glutenproject/backendsapi/BackendSettingsApi.scala @@ -132,4 +132,6 @@ trait BackendSettingsApi { def mergeTwoPhasesHashBaseAggregateIfNeed(): Boolean = false def shouldRewriteTypedImperativeAggregate(): Boolean = false + + def shouldRewriteCollectSet(): Boolean = false } diff --git a/gluten-core/src/main/scala/io/glutenproject/extension/ColumnarOverrides.scala b/gluten-core/src/main/scala/io/glutenproject/extension/ColumnarOverrides.scala index 685639a7e59e1..feb5479317ecf 100644 --- a/gluten-core/src/main/scala/io/glutenproject/extension/ColumnarOverrides.scala +++ b/gluten-core/src/main/scala/io/glutenproject/extension/ColumnarOverrides.scala @@ -568,6 +568,7 @@ object ColumnarOverrideRules { def 
rewriteSparkPlanRule(): Rule[SparkPlan] = { val rewriteRules = Seq( RewriteMultiChildrenCount, + RewriteCollectSet, RewriteTypedImperativeAggregate, PullOutPreProject, PullOutPostProject) diff --git a/gluten-core/src/main/scala/io/glutenproject/extension/RewriteCollectSet.scala b/gluten-core/src/main/scala/io/glutenproject/extension/RewriteCollectSet.scala new file mode 100644 index 0000000000000..44e52d5f08863 --- /dev/null +++ b/gluten-core/src/main/scala/io/glutenproject/extension/RewriteCollectSet.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.glutenproject.extension + +import io.glutenproject.backendsapi.BackendsApiManager +import io.glutenproject.utils.PullOutProjectHelper + +import org.apache.spark.sql.catalyst.expressions.{And, IsNotNull} +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, CollectSet, Complete, Partial} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.aggregate.BaseAggregateExec + +/** + * This rule adds an `IsNotNull` filter to skip null values before going to native collect_set. + * + * TODO: remove this rule once Velox collect_set skips null values. + */ +object RewriteCollectSet extends Rule[SparkPlan] with PullOutProjectHelper { + private lazy val shouldRewriteCollect = + BackendsApiManager.getSettings.shouldRewriteCollectSet() + + private def shouldRewrite(ae: AggregateExpression): Boolean = { + ae.aggregateFunction match { + case _: CollectSet => + ae.mode match { + case Partial | Complete => true + case _ => false + } + case _ => false + } + } + + private def rewriteCollectFilter(aggExprs: Seq[AggregateExpression]): Seq[AggregateExpression] = { + aggExprs + .map { + aggExpr => + if (shouldRewrite(aggExpr)) { + val newFilter = + (aggExpr.filter ++ Seq(IsNotNull(aggExpr.aggregateFunction.children.head))) + .reduce(And) + aggExpr.copy(filter = Option(newFilter)) + } else { + aggExpr + } + } + } + + override def apply(plan: SparkPlan): SparkPlan = { + if (!shouldRewriteCollect) { + return plan + } + + plan match { + case agg: BaseAggregateExec if agg.aggregateExpressions.exists(shouldRewrite) => + val newAggExprs = rewriteCollectFilter(agg.aggregateExpressions) + copyBaseAggregateExec(agg)(newAggregateExpressions = newAggExprs) + + case _ => plan + } + } +}