diff --git a/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/VeloxBackend.scala b/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/VeloxBackend.scala
index cf0fc15556a0a..c316d50f02d7c 100644
--- a/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/VeloxBackend.scala
+++ b/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/VeloxBackend.scala
@@ -480,4 +480,6 @@ object BackendSettings extends BackendSettingsApi {
     // vanilla Spark, we need to rewrite the aggregate to get the correct data type.
     true
   }
+
+  override def shouldRewriteCollect(): Boolean = true
 }
diff --git a/backends-velox/src/test/scala/io/glutenproject/execution/VeloxAggregateFunctionsSuite.scala b/backends-velox/src/test/scala/io/glutenproject/execution/VeloxAggregateFunctionsSuite.scala
index 4d4db843ea261..48b0a36c01381 100644
--- a/backends-velox/src/test/scala/io/glutenproject/execution/VeloxAggregateFunctionsSuite.scala
+++ b/backends-velox/src/test/scala/io/glutenproject/execution/VeloxAggregateFunctionsSuite.scala
@@ -699,6 +699,132 @@ abstract class VeloxAggregateFunctionsSuite extends VeloxWholeStageTransformerSu
     }
   }
 
+  test("test collect_set") {
+    runQueryAndCompare("SELECT array_sort(collect_set(l_partkey)) FROM lineitem") {
+      df =>
+        {
+          assert(
+            getExecutedPlan(df).count(
+              plan => {
+                plan.isInstanceOf[HashAggregateExecTransformer]
+              }) == 2)
+        }
+    }
+
+    runQueryAndCompare(
+      """
+        |SELECT array_sort(collect_set(l_suppkey)), array_sort(collect_set(l_partkey))
+        |FROM lineitem
+        |""".stripMargin) {
+      df =>
+        {
+          assert(
+            getExecutedPlan(df).count(
+              plan => {
+                plan.isInstanceOf[HashAggregateExecTransformer]
+              }) == 2)
+        }
+    }
+
+    runQueryAndCompare(
+      "SELECT count(distinct l_suppkey), array_sort(collect_set(l_partkey)) FROM lineitem") {
+      df =>
+        {
+          assert(
+            getExecutedPlan(df).count(
+              plan => {
+                plan.isInstanceOf[HashAggregateExecTransformer]
+              }) == 4)
+        }
+    }
+  }
+
+  test("test collect_set/collect_list with null") {
+    import testImplicits._
+
+    withTempView("collect_tmp") {
+      Seq((1, null), (1, "a"), (2, null), (3, null), (3, null), (4, "b"))
+        .toDF("c1", "c2")
+        .createOrReplaceTempView("collect_tmp")
+
+      // basic test
+      runQueryAndCompare("SELECT collect_set(c2), collect_list(c2) FROM collect_tmp GROUP BY c1") {
+        df =>
+          {
+            assert(
+              getExecutedPlan(df).count(
+                plan => {
+                  plan.isInstanceOf[HashAggregateExecTransformer]
+                }) == 2)
+          }
+      }
+
+      // test pre project and post project
+      runQueryAndCompare("""
+                           |SELECT
+                           |size(collect_set(if(c2 = 'a', 'x', 'y'))) as x,
+                           |size(collect_list(if(c2 = 'a', 'x', 'y'))) as y
+                           |FROM collect_tmp GROUP BY c1
+                           |""".stripMargin) {
+        df =>
+          {
+            assert(
+              getExecutedPlan(df).count(
+                plan => {
+                  plan.isInstanceOf[HashAggregateExecTransformer]
+                }) == 2)
+          }
+      }
+
+      // test distinct
+      runQueryAndCompare(
+        "SELECT collect_set(c2), collect_list(distinct c2) FROM collect_tmp GROUP BY c1") {
+        df =>
+          {
+            assert(
+              getExecutedPlan(df).count(
+                plan => {
+                  plan.isInstanceOf[HashAggregateExecTransformer]
+                }) == 4)
+          }
+      }
+
+      // test distinct + pre project and post project
+      runQueryAndCompare("""
+                           |SELECT
+                           |size(collect_set(if(c2 = 'a', 'x', 'y'))),
+                           |size(collect_list(distinct if(c2 = 'a', 'x', 'y')))
+                           |FROM collect_tmp GROUP BY c1
+                           |""".stripMargin) {
+        df =>
+          {
+            assert(
+              getExecutedPlan(df).count(
+                plan => {
+                  plan.isInstanceOf[HashAggregateExecTransformer]
+                }) == 4)
+          }
+      }
+
+      // test cast array to string
+      runQueryAndCompare("""
+                           |SELECT
+                           |cast(collect_set(c2) as string),
+                           |cast(collect_list(c2) as string)
+                           |FROM collect_tmp GROUP BY c1
+                           |""".stripMargin) {
+        df =>
+          {
+            assert(
+              getExecutedPlan(df).count(
+                plan => {
+                  plan.isInstanceOf[HashAggregateExecTransformer]
+                }) == 2)
+          }
+      }
+    }
+  }
+
   test("count(1)") {
     runQueryAndCompare(
       """
diff --git a/cpp/velox/substrait/SubstraitParser.cc b/cpp/velox/substrait/SubstraitParser.cc
index c18e3d85930ce..479e41876e474 100644
--- a/cpp/velox/substrait/SubstraitParser.cc
+++ b/cpp/velox/substrait/SubstraitParser.cc
@@ -375,20 +375,12 @@ std::unordered_map<std::string, std::string> SubstraitParser::substraitVeloxFunc
     {"starts_with", "startswith"},
     {"named_struct", "row_constructor"},
     {"bit_or", "bitwise_or_agg"},
-    {"bit_or_partial", "bitwise_or_agg_partial"},
-    {"bit_or_merge", "bitwise_or_agg_merge"},
     {"bit_and", "bitwise_and_agg"},
-    {"bit_and_partial", "bitwise_and_agg_partial"},
-    {"bit_and_merge", "bitwise_and_agg_merge"},
     {"murmur3hash", "hash_with_seed"},
     {"modulus", "remainder"},
     {"date_format", "format_datetime"},
     {"collect_set", "set_agg"},
-    {"collect_set_partial", "set_agg_partial"},
-    {"collect_set_merge", "set_agg_merge"},
-    {"collect_list", "array_agg"},
-    {"collect_list_partial", "array_agg_partial"},
-    {"collect_list_merge", "array_agg_merge"}};
+    {"collect_list", "array_agg"}};
 
 const std::unordered_map<std::string, std::string> SubstraitParser::typeMap_ = {
     {"bool", "BOOLEAN"},
diff --git a/gluten-core/src/main/scala/io/glutenproject/backendsapi/BackendSettingsApi.scala b/gluten-core/src/main/scala/io/glutenproject/backendsapi/BackendSettingsApi.scala
index 74973a60c6761..83db8a8da85ee 100644
--- a/gluten-core/src/main/scala/io/glutenproject/backendsapi/BackendSettingsApi.scala
+++ b/gluten-core/src/main/scala/io/glutenproject/backendsapi/BackendSettingsApi.scala
@@ -132,4 +132,6 @@ trait BackendSettingsApi {
   def mergeTwoPhasesHashBaseAggregateIfNeed(): Boolean = false
 
   def shouldRewriteTypedImperativeAggregate(): Boolean = false
+
+  def shouldRewriteCollect(): Boolean = false
 }
diff --git a/gluten-core/src/main/scala/io/glutenproject/extension/ColumnarOverrides.scala b/gluten-core/src/main/scala/io/glutenproject/extension/ColumnarOverrides.scala
index 21fc6a2ff7707..f58086dd01314 100644
--- a/gluten-core/src/main/scala/io/glutenproject/extension/ColumnarOverrides.scala
+++ b/gluten-core/src/main/scala/io/glutenproject/extension/ColumnarOverrides.scala
@@ -561,6 +561,7 @@ object ColumnarOverrideRules {
   val rewriteRules = Seq(
     RewriteIn,
     RewriteMultiChildrenCount,
+    RewriteCollect,
     RewriteTypedImperativeAggregate,
     PullOutPreProject,
     PullOutPostProject)
diff --git a/gluten-core/src/main/scala/io/glutenproject/extension/RewriteCollect.scala b/gluten-core/src/main/scala/io/glutenproject/extension/RewriteCollect.scala
new file mode 100644
index 0000000000000..b7b58acc8c80d
--- /dev/null
+++ b/gluten-core/src/main/scala/io/glutenproject/extension/RewriteCollect.scala
@@ -0,0 +1,141 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.glutenproject.extension
+
+import io.glutenproject.backendsapi.BackendsApiManager
+import io.glutenproject.utils.PullOutProjectHelper
+
+import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeSet, If, IsNotNull, IsNull, Literal, NamedExpression}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, CollectList, CollectSet, Complete, Final, Partial}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.aggregate.BaseAggregateExec
+import org.apache.spark.sql.types.ArrayType
+
+import scala.collection.mutable.ArrayBuffer
+
+/**
+ * This rule rewrites collect_set and collect_list to be compatible with vanilla Spark.
+ *
+ * - Add `IsNotNull(partial_in)` to skip null values before they reach native collect_set
+ * - Add `If(IsNull(result), Literal.create(Seq.empty, dataType), result)` to replace a null
+ *   result with an empty array
+ *
+ * TODO: remove this rule once Velox is compatible with vanilla Spark.
+ */
+object RewriteCollect extends Rule[SparkPlan] with PullOutProjectHelper {
+  private lazy val shouldRewriteCollect =
+    BackendsApiManager.getSettings.shouldRewriteCollect()
+
+  private def shouldAddIsNotNull(ae: AggregateExpression): Boolean = {
+    ae.aggregateFunction match {
+      case c: CollectSet if c.child.nullable =>
+        ae.mode match {
+          case Partial | Complete => true
+          case _ => false
+        }
+      case _ => false
+    }
+  }
+
+  private def shouldReplaceNullToEmptyArray(ae: AggregateExpression): Boolean = {
+    ae.aggregateFunction match {
+      case _: CollectSet | _: CollectList =>
+        ae.mode match {
+          case Final | Complete => true
+          case _ => false
+        }
+      case _ => false
+    }
+  }
+
+  private def shouldRewrite(agg: BaseAggregateExec): Boolean = {
+    agg.aggregateExpressions.exists {
+      ae => shouldAddIsNotNull(ae) || shouldReplaceNullToEmptyArray(ae)
+    }
+  }
+
+  private def rewriteCollectFilter(aggExprs: Seq[AggregateExpression]): Seq[AggregateExpression] = {
+    aggExprs
+      .map {
+        aggExpr =>
+          if (shouldAddIsNotNull(aggExpr)) {
+            val newFilter =
+              (aggExpr.filter ++ Seq(IsNotNull(aggExpr.aggregateFunction.children.head)))
+                .reduce(And)
+            aggExpr.copy(filter = Option(newFilter))
+          } else {
+            aggExpr
+          }
+      }
+  }
+
+  private def rewriteAttributesAndResultExpressions(
+      agg: BaseAggregateExec): (Seq[Attribute], Seq[NamedExpression]) = {
+    val rewriteAggExprIndices = agg.aggregateExpressions.zipWithIndex
+      .filter(exprAndIndex => shouldReplaceNullToEmptyArray(exprAndIndex._1))
+      .map(_._2)
+      .toSet
+    if (rewriteAggExprIndices.isEmpty) {
+      return (agg.aggregateAttributes, agg.resultExpressions)
+    }
+
+    assert(agg.aggregateExpressions.size == agg.aggregateAttributes.size)
+    val rewriteAggAttributes = new ArrayBuffer[Attribute]()
+    val newAggregateAttributes = agg.aggregateAttributes.zipWithIndex.map {
+      case (attr, index) =>
+        if (rewriteAggExprIndices.contains(index)) {
+          rewriteAggAttributes.append(attr)
+          // Mark the attribute as nullable: collect_set and collect_list are not nullable in
+          // vanilla Spark, but Velox may return null. This avoids potential issues when
+          // the post project falls back to vanilla Spark.
+          attr.withNullability(true)
+        } else {
+          attr
+        }
+    }
+    val rewriteAggAttributeSet = AttributeSet(rewriteAggAttributes)
+    val newResultExpressions = agg.resultExpressions.map {
+      ne =>
+        val rewritten = ne.transformUp {
+          case attr: Attribute if rewriteAggAttributeSet.contains(attr) =>
+            assert(attr.dataType.isInstanceOf[ArrayType])
+            If(IsNull(attr), Literal.create(Seq.empty, attr.dataType), attr)
+        }
+        assert(rewritten.isInstanceOf[NamedExpression])
+        rewritten.asInstanceOf[NamedExpression]
+    }
+    (newAggregateAttributes, newResultExpressions)
+  }
+
+  override def apply(plan: SparkPlan): SparkPlan = {
+    if (!shouldRewriteCollect) {
+      return plan
+    }
+
+    plan match {
+      case agg: BaseAggregateExec if shouldRewrite(agg) =>
+        val newAggExprs = rewriteCollectFilter(agg.aggregateExpressions)
+        val (newAttributes, newResultExprs) = rewriteAttributesAndResultExpressions(agg)
+        copyBaseAggregateExec(agg)(
+          newAggregateExpressions = newAggExprs,
+          newAggregateAttributes = newAttributes,
+          newResultExpressions = newResultExprs)
+
+      case _ => plan
+    }
+  }
+}
diff --git a/gluten-core/src/main/scala/io/glutenproject/utils/PullOutProjectHelper.scala b/gluten-core/src/main/scala/io/glutenproject/utils/PullOutProjectHelper.scala
index 4191e08e7ad3f..31acc96989e3a 100644
--- a/gluten-core/src/main/scala/io/glutenproject/utils/PullOutProjectHelper.scala
+++ b/gluten-core/src/main/scala/io/glutenproject/utils/PullOutProjectHelper.scala
@@ -83,24 +83,28 @@ trait PullOutProjectHelper {
   protected def copyBaseAggregateExec(agg: BaseAggregateExec)(
       newGroupingExpressions: Seq[NamedExpression] = agg.groupingExpressions,
       newAggregateExpressions: Seq[AggregateExpression] = agg.aggregateExpressions,
+      newAggregateAttributes: Seq[Attribute] = agg.aggregateAttributes,
       newResultExpressions: Seq[NamedExpression] = agg.resultExpressions
   ): BaseAggregateExec = agg match {
     case hash: HashAggregateExec =>
       hash.copy(
         groupingExpressions = newGroupingExpressions,
         aggregateExpressions = newAggregateExpressions,
+        aggregateAttributes = newAggregateAttributes,
         resultExpressions = newResultExpressions
       )
     case sort: SortAggregateExec =>
       sort.copy(
         groupingExpressions = newGroupingExpressions,
         aggregateExpressions = newAggregateExpressions,
+        aggregateAttributes = newAggregateAttributes,
         resultExpressions = newResultExpressions
       )
     case objectHash: ObjectHashAggregateExec =>
       objectHash.copy(
         groupingExpressions = newGroupingExpressions,
         aggregateExpressions = newAggregateExpressions,
+        aggregateAttributes = newAggregateAttributes,
         resultExpressions = newResultExpressions
       )
     case _ =>
diff --git a/gluten-ut/spark32/src/test/scala/io/glutenproject/utils/velox/VeloxTestSettings.scala b/gluten-ut/spark32/src/test/scala/io/glutenproject/utils/velox/VeloxTestSettings.scala
index 1469606d01495..bda6c916149ca 100644
--- a/gluten-ut/spark32/src/test/scala/io/glutenproject/utils/velox/VeloxTestSettings.scala
+++ b/gluten-ut/spark32/src/test/scala/io/glutenproject/utils/velox/VeloxTestSettings.scala
@@ -54,9 +54,7 @@ class VeloxTestSettings extends BackendTestSettings {
      "SPARK-32038: NormalizeFloatingNumbers should work on distinct aggregate",
       // Replaced with another test.
       "SPARK-19471: AggregationIterator does not initialize the generated result projection" +
-        " before using it",
-      // TODO: fix inconsistent behavior.
- "SPARK-17641: collect functions should not collect null values" + " before using it" ) enableSuite[GlutenCastSuite] @@ -222,7 +220,6 @@ class VeloxTestSettings extends BackendTestSettings { .exclude("from_unixtime") enableSuite[GlutenDecimalExpressionSuite] enableSuite[GlutenStringFunctionsSuite] - .exclude("SPARK-31993: concat_ws in agg function with plenty of string/array types columns") enableSuite[GlutenRegexpExpressionsSuite] enableSuite[GlutenNullExpressionsSuite] enableSuite[GlutenPredicateSuite] diff --git a/gluten-ut/spark33/src/test/scala/io/glutenproject/utils/velox/VeloxTestSettings.scala b/gluten-ut/spark33/src/test/scala/io/glutenproject/utils/velox/VeloxTestSettings.scala index 355fa92012bcd..aacabf95b4d81 100644 --- a/gluten-ut/spark33/src/test/scala/io/glutenproject/utils/velox/VeloxTestSettings.scala +++ b/gluten-ut/spark33/src/test/scala/io/glutenproject/utils/velox/VeloxTestSettings.scala @@ -45,7 +45,6 @@ import org.apache.spark.sql.sources.{GlutenBucketedReadWithoutHiveSupportSuite, class VeloxTestSettings extends BackendTestSettings { enableSuite[GlutenStringFunctionsSuite] - .exclude("SPARK-31993: concat_ws in agg function with plenty of string/array types columns") enableSuite[GlutenBloomFilterAggregateQuerySuite] enableSuite[GlutenDataSourceV2DataFrameSessionCatalogSuite] enableSuite[GlutenDataSourceV2DataFrameSuite] @@ -930,9 +929,7 @@ class VeloxTestSettings extends BackendTestSettings { "SPARK-32038: NormalizeFloatingNumbers should work on distinct aggregate", // Replaced with another test. "SPARK-19471: AggregationIterator does not initialize the generated result projection" + - " before using it", - // TODO: fix inconsistent behavior. - "SPARK-17641: collect functions should not collect null values" + " before using it" ) enableSuite[GlutenDataFrameAsOfJoinSuite] enableSuite[GlutenDataFrameComplexTypeSuite] diff --git a/gluten-ut/spark34/src/test/scala/io/glutenproject/utils/velox/VeloxTestSettings.scala b/gluten-ut/spark34/src/test/scala/io/glutenproject/utils/velox/VeloxTestSettings.scala index 73f8ac6f661db..0dbe1d70ecd91 100644 --- a/gluten-ut/spark34/src/test/scala/io/glutenproject/utils/velox/VeloxTestSettings.scala +++ b/gluten-ut/spark34/src/test/scala/io/glutenproject/utils/velox/VeloxTestSettings.scala @@ -45,7 +45,6 @@ import org.apache.spark.sql.sources.{GlutenBucketedReadWithoutHiveSupportSuite, class VeloxTestSettings extends BackendTestSettings { enableSuite[GlutenStringFunctionsSuite] - .exclude("SPARK-31993: concat_ws in agg function with plenty of string/array types columns") enableSuite[GlutenBloomFilterAggregateQuerySuite] enableSuite[GlutenDataSourceV2DataFrameSessionCatalogSuite] enableSuite[GlutenDataSourceV2DataFrameSuite] @@ -937,9 +936,7 @@ class VeloxTestSettings extends BackendTestSettings { "SPARK-32038: NormalizeFloatingNumbers should work on distinct aggregate", // Replaced with another test. "SPARK-19471: AggregationIterator does not initialize the generated result projection" + - " before using it", - // TODO: fix inconsistent behavior. - "SPARK-17641: collect functions should not collect null values" + " before using it" ) enableSuite[GlutenDataFrameAsOfJoinSuite] enableSuite[GlutenDataFrameComplexTypeSuite]