diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala index 8984a9551b03..5975a20a267d 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala @@ -35,7 +35,7 @@ import org.apache.spark.shuffle.utils.ShuffleUtil import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, CollectList, CollectSet} import org.apache.spark.sql.catalyst.optimizer.BuildSide import org.apache.spark.sql.catalyst.plans.JoinType import org.apache.spark.sql.catalyst.plans.physical._ @@ -759,7 +759,9 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi { Sig[UserDefinedAggregateFunction](ExpressionNames.UDAF_PLACEHOLDER), Sig[NaNvl](ExpressionNames.NANVL), Sig[VeloxCollectList](ExpressionNames.COLLECT_LIST), + Sig[CollectList](ExpressionNames.COLLECT_LIST), Sig[VeloxCollectSet](ExpressionNames.COLLECT_SET), + Sig[CollectSet](ExpressionNames.COLLECT_SET), Sig[VeloxBloomFilterMightContain](ExpressionNames.MIGHT_CONTAIN), Sig[VeloxBloomFilterAggregate](ExpressionNames.BLOOM_FILTER_AGG), // For test purpose. diff --git a/backends-velox/src/main/scala/org/apache/gluten/extension/CollectRewriteRule.scala b/backends-velox/src/main/scala/org/apache/gluten/extension/CollectRewriteRule.scala index 48541b234e36..e76de56374f5 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/extension/CollectRewriteRule.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/extension/CollectRewriteRule.scala @@ -20,11 +20,11 @@ import org.apache.gluten.expression.ExpressionMappings import org.apache.gluten.expression.aggregate.{VeloxCollectList, VeloxCollectSet} import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst.expressions.{Expression, WindowExpression} +import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Window} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan} import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.catalyst.trees.TreePattern.{AGGREGATE, AGGREGATE_EXPRESSION, WINDOW, WINDOW_EXPRESSION} +import org.apache.spark.sql.catalyst.trees.TreePattern.{AGGREGATE, AGGREGATE_EXPRESSION} import scala.reflect.{classTag, ClassTag} @@ -40,7 +40,7 @@ case class CollectRewriteRule(spark: SparkSession) extends Rule[LogicalPlan] { return plan } - val newPlan = plan.transformUpWithPruning(_.containsAnyPattern(WINDOW, AGGREGATE)) { + val newPlan = plan.transformUpWithPruning(_.containsPattern(AGGREGATE)) { case node => replaceAggCollect(node) } @@ -57,12 +57,6 @@ case class CollectRewriteRule(spark: SparkSession) extends Rule[LogicalPlan] { case ToVeloxCollect(newAggExpr) => newAggExpr } - case w: Window => - w.transformExpressionsWithPruning( - _.containsAllPatterns(AGGREGATE_EXPRESSION, WINDOW_EXPRESSION)) { - case windowExpr @ WindowExpression(ToVeloxCollect(newAggExpr), _) => - windowExpr.copy(newAggExpr) - } case other => other } }