[VL] Rewrite collect_set and collect_list aggregate function (apache#…
ulysses-you authored and loneylee committed Mar 13, 2024
1 parent f83236c commit 8383e33
Showing 10 changed files with 280 additions and 21 deletions.
@@ -480,4 +480,6 @@ object BackendSettings extends BackendSettingsApi {
// vanilla Spark, we need to rewrite the aggregate to get the correct data type.
true
}

override def shouldRewriteCollect(): Boolean = true
}
@@ -699,6 +699,132 @@ abstract class VeloxAggregateFunctionsSuite extends VeloxWholeStageTransformerSu
}
}

test("test collect_set") {
runQueryAndCompare("SELECT array_sort(collect_set(l_partkey)) FROM lineitem") {
df =>
{
assert(
getExecutedPlan(df).count(
plan => {
plan.isInstanceOf[HashAggregateExecTransformer]
}) == 2)
}
}

runQueryAndCompare(
"""
|SELECT array_sort(collect_set(l_suppkey)), array_sort(collect_set(l_partkey))
|FROM lineitem
|""".stripMargin) {
df =>
{
assert(
getExecutedPlan(df).count(
plan => {
plan.isInstanceOf[HashAggregateExecTransformer]
}) == 2)
}
}

runQueryAndCompare(
"SELECT count(distinct l_suppkey), array_sort(collect_set(l_partkey)) FROM lineitem") {
df =>
{
assert(
getExecutedPlan(df).count(
plan => {
plan.isInstanceOf[HashAggregateExecTransformer]
}) == 4)
}
}
}

test("test collect_set/collect_list with null") {
import testImplicits._

withTempView("collect_tmp") {
Seq((1, null), (1, "a"), (2, null), (3, null), (3, null), (4, "b"))
.toDF("c1", "c2")
.createOrReplaceTempView("collect_tmp")

// basic test
runQueryAndCompare("SELECT collect_set(c2), collect_list(c2) FROM collect_tmp GROUP BY c1") {
df =>
{
assert(
getExecutedPlan(df).count(
plan => {
plan.isInstanceOf[HashAggregateExecTransformer]
}) == 2)
}
}

// test pre project and post project
runQueryAndCompare("""
|SELECT
|size(collect_set(if(c2 = 'a', 'x', 'y'))) as x,
|size(collect_list(if(c2 = 'a', 'x', 'y'))) as y
|FROM collect_tmp GROUP BY c1
|""".stripMargin) {
df =>
{
assert(
getExecutedPlan(df).count(
plan => {
plan.isInstanceOf[HashAggregateExecTransformer]
}) == 2)
}
}

// test distinct
runQueryAndCompare(
"SELECT collect_set(c2), collect_list(distinct c2) FROM collect_tmp GROUP BY c1") {
df =>
{
assert(
getExecutedPlan(df).count(
plan => {
plan.isInstanceOf[HashAggregateExecTransformer]
}) == 4)
}
}

// test distinct + pre project and post project
runQueryAndCompare("""
|SELECT
|size(collect_set(if(c2 = 'a', 'x', 'y'))),
|size(collect_list(distinct if(c2 = 'a', 'x', 'y')))
|FROM collect_tmp GROUP BY c1
|""".stripMargin) {
df =>
{
assert(
getExecutedPlan(df).count(
plan => {
plan.isInstanceOf[HashAggregateExecTransformer]
}) == 4)
}
}

// test cast array to string
runQueryAndCompare("""
|SELECT
|cast(collect_set(c2) as string),
|cast(collect_list(c2) as string)
|FROM collect_tmp GROUP BY c1
|""".stripMargin) {
df =>
{
assert(
getExecutedPlan(df).count(
plan => {
plan.isInstanceOf[HashAggregateExecTransformer]
}) == 2)
}
}
}
}

test("count(1)") {
runQueryAndCompare(
"""
10 changes: 1 addition & 9 deletions cpp/velox/substrait/SubstraitParser.cc
@@ -375,20 +375,12 @@ std::unordered_map<std::string, std::string> SubstraitParser::substraitVeloxFunc
{"starts_with", "startswith"},
{"named_struct", "row_constructor"},
{"bit_or", "bitwise_or_agg"},
{"bit_or_partial", "bitwise_or_agg_partial"},
{"bit_or_merge", "bitwise_or_agg_merge"},
{"bit_and", "bitwise_and_agg"},
{"bit_and_partial", "bitwise_and_agg_partial"},
{"bit_and_merge", "bitwise_and_agg_merge"},
{"murmur3hash", "hash_with_seed"},
{"modulus", "remainder"},
{"date_format", "format_datetime"},
{"collect_set", "set_agg"},
{"collect_set_partial", "set_agg_partial"},
{"collect_set_merge", "set_agg_merge"},
{"collect_list", "array_agg"},
{"collect_list_partial", "array_agg_partial"},
{"collect_list_merge", "array_agg_merge"}};
{"collect_list", "array_agg"}};

const std::unordered_map<std::string, std::string> SubstraitParser::typeMap_ = {
{"bool", "BOOLEAN"},
@@ -132,4 +132,6 @@ trait BackendSettingsApi {
def mergeTwoPhasesHashBaseAggregateIfNeed(): Boolean = false

def shouldRewriteTypedImperativeAggregate(): Boolean = false

def shouldRewriteCollect(): Boolean = false
}
@@ -561,6 +561,7 @@ object ColumnarOverrideRules {
val rewriteRules = Seq(
RewriteIn,
RewriteMultiChildrenCount,
RewriteCollect,
RewriteTypedImperativeAggregate,
PullOutPreProject,
PullOutPostProject)
@@ -0,0 +1,141 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.glutenproject.extension

import io.glutenproject.backendsapi.BackendsApiManager
import io.glutenproject.utils.PullOutProjectHelper

import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeSet, If, IsNotNull, IsNull, Literal, NamedExpression}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, CollectList, CollectSet, Complete, Final, Partial}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.aggregate.BaseAggregateExec
import org.apache.spark.sql.types.ArrayType

import scala.collection.mutable.ArrayBuffer

/**
* This rule rewrites collect_set and collect_list to be compatible with vanilla Spark.
*
* - Add `IsNotNull(partial_in)` to skip null values before going to native collect_set
* - Add `If(IsNull(result), CreateArray(Seq.empty), result)` to replace null with an empty array
*
* TODO: remove this rule once Velox is compatible with vanilla Spark.
*/
object RewriteCollect extends Rule[SparkPlan] with PullOutProjectHelper {
private lazy val shouldRewriteCollect =
BackendsApiManager.getSettings.shouldRewriteCollect()

private def shouldAddIsNotNull(ae: AggregateExpression): Boolean = {
ae.aggregateFunction match {
case c: CollectSet if c.child.nullable =>
ae.mode match {
case Partial | Complete => true
case _ => false
}
case _ => false
}
}

private def shouldReplaceNullToEmptyArray(ae: AggregateExpression): Boolean = {
ae.aggregateFunction match {
case _: CollectSet | _: CollectList =>
ae.mode match {
case Final | Complete => true
case _ => false
}
case _ => false
}
}

private def shouldRewrite(agg: BaseAggregateExec): Boolean = {
agg.aggregateExpressions.exists {
ae => shouldAddIsNotNull(ae) || shouldReplaceNullToEmptyArray(ae)
}
}

private def rewriteCollectFilter(aggExprs: Seq[AggregateExpression]): Seq[AggregateExpression] = {
aggExprs
.map {
aggExpr =>
if (shouldAddIsNotNull(aggExpr)) {
val newFilter =
(aggExpr.filter ++ Seq(IsNotNull(aggExpr.aggregateFunction.children.head)))
.reduce(And)
aggExpr.copy(filter = Option(newFilter))
} else {
aggExpr
}
}
}

private def rewriteAttributesAndResultExpressions(
agg: BaseAggregateExec): (Seq[Attribute], Seq[NamedExpression]) = {
val rewriteAggExprIndices = agg.aggregateExpressions.zipWithIndex
.filter(exprAndIndex => shouldReplaceNullToEmptyArray(exprAndIndex._1))
.map(_._2)
.toSet
if (rewriteAggExprIndices.isEmpty) {
return (agg.aggregateAttributes, agg.resultExpressions)
}

assert(agg.aggregateExpressions.size == agg.aggregateAttributes.size)
val rewriteAggAttributes = new ArrayBuffer[Attribute]()
val newAggregateAttributes = agg.aggregateAttributes.zipWithIndex.map {
case (attr, index) =>
if (rewriteAggExprIndices.contains(index)) {
rewriteAggAttributes.append(attr)
// We should mark the attribute as nullable since collect_set and collect_list
// are not nullable but Velox may return null. This avoids potential issues when
// the post project falls back to vanilla Spark.
attr.withNullability(true)
} else {
attr
}
}
val rewriteAggAttributeSet = AttributeSet(rewriteAggAttributes)
val newResultExpressions = agg.resultExpressions.map {
ne =>
val rewritten = ne.transformUp {
case attr: Attribute if rewriteAggAttributeSet.contains(attr) =>
assert(attr.dataType.isInstanceOf[ArrayType])
If(IsNull(attr), Literal.create(Seq.empty, attr.dataType), attr)
}
assert(rewritten.isInstanceOf[NamedExpression])
rewritten.asInstanceOf[NamedExpression]
}
(newAggregateAttributes, newResultExpressions)
}

override def apply(plan: SparkPlan): SparkPlan = {
if (!shouldRewriteCollect) {
return plan
}

plan match {
case agg: BaseAggregateExec if shouldRewrite(agg) =>
val newAggExprs = rewriteCollectFilter(agg.aggregateExpressions)
val (newAttributes, newResultExprs) = rewriteAttributesAndResultExpressions(agg)
copyBaseAggregateExec(agg)(
newAggregateExpressions = newAggExprs,
newAggregateAttributes = newAttributes,
newResultExpressions = newResultExprs)

case _ => plan
}
}
}
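
Not part of the commit, but for context: a minimal standalone Scala sketch (class and app names hypothetical) of the vanilla Spark behavior that the RewriteCollect rule above emulates. collect_set and collect_list skip null inputs and yield an empty array (never null) per group; the injected IsNotNull filter reproduces the first property on the native side, and the post-aggregation If(IsNull(...), empty array, ...) reproduces the second.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{collect_list, collect_set}

object CollectNullSemanticsExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("collect-null-semantics").getOrCreate()
    import spark.implicits._

    // Group 2 contains only nulls: vanilla Spark skips them and returns an empty
    // array rather than null, which is what the rewrite preserves on Velox.
    val df = Seq((1, Option("a")), (1, Option.empty[String]), (2, Option.empty[String]))
      .toDF("c1", "c2")

    df.groupBy($"c1")
      .agg(collect_set($"c2").as("set_c2"), collect_list($"c2").as("list_c2"))
      .orderBy($"c1")
      .show()
    // Expected (vanilla Spark):
    //   c1 | set_c2 | list_c2
    //   1  | [a]    | [a]
    //   2  | []     | []

    spark.stop()
  }
}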
@@ -83,24 +83,28 @@ trait PullOutProjectHelper {
protected def copyBaseAggregateExec(agg: BaseAggregateExec)(
newGroupingExpressions: Seq[NamedExpression] = agg.groupingExpressions,
newAggregateExpressions: Seq[AggregateExpression] = agg.aggregateExpressions,
newAggregateAttributes: Seq[Attribute] = agg.aggregateAttributes,
newResultExpressions: Seq[NamedExpression] = agg.resultExpressions
): BaseAggregateExec = agg match {
case hash: HashAggregateExec =>
hash.copy(
groupingExpressions = newGroupingExpressions,
aggregateExpressions = newAggregateExpressions,
aggregateAttributes = newAggregateAttributes,
resultExpressions = newResultExpressions
)
case sort: SortAggregateExec =>
sort.copy(
groupingExpressions = newGroupingExpressions,
aggregateExpressions = newAggregateExpressions,
aggregateAttributes = newAggregateAttributes,
resultExpressions = newResultExpressions
)
case objectHash: ObjectHashAggregateExec =>
objectHash.copy(
groupingExpressions = newGroupingExpressions,
aggregateExpressions = newAggregateExpressions,
aggregateAttributes = newAggregateAttributes,
resultExpressions = newResultExpressions
)
case _ =>
@@ -54,9 +54,7 @@ class VeloxTestSettings extends BackendTestSettings {
"SPARK-32038: NormalizeFloatingNumbers should work on distinct aggregate",
// Replaced with another test.
"SPARK-19471: AggregationIterator does not initialize the generated result projection" +
" before using it",
// TODO: fix inconsistent behavior.
"SPARK-17641: collect functions should not collect null values"
" before using it"
)

enableSuite[GlutenCastSuite]
@@ -222,7 +220,6 @@ class VeloxTestSettings extends BackendTestSettings {
.exclude("from_unixtime")
enableSuite[GlutenDecimalExpressionSuite]
enableSuite[GlutenStringFunctionsSuite]
.exclude("SPARK-31993: concat_ws in agg function with plenty of string/array types columns")
enableSuite[GlutenRegexpExpressionsSuite]
enableSuite[GlutenNullExpressionsSuite]
enableSuite[GlutenPredicateSuite]
@@ -45,7 +45,6 @@ import org.apache.spark.sql.sources.{GlutenBucketedReadWithoutHiveSupportSuite,

class VeloxTestSettings extends BackendTestSettings {
enableSuite[GlutenStringFunctionsSuite]
.exclude("SPARK-31993: concat_ws in agg function with plenty of string/array types columns")
enableSuite[GlutenBloomFilterAggregateQuerySuite]
enableSuite[GlutenDataSourceV2DataFrameSessionCatalogSuite]
enableSuite[GlutenDataSourceV2DataFrameSuite]
@@ -930,9 +929,7 @@ class VeloxTestSettings extends BackendTestSettings {
"SPARK-32038: NormalizeFloatingNumbers should work on distinct aggregate",
// Replaced with another test.
"SPARK-19471: AggregationIterator does not initialize the generated result projection" +
" before using it",
// TODO: fix inconsistent behavior.
"SPARK-17641: collect functions should not collect null values"
" before using it"
)
enableSuite[GlutenDataFrameAsOfJoinSuite]
enableSuite[GlutenDataFrameComplexTypeSuite]