From ab378628afefb397595fa87583292f9b9345f915 Mon Sep 17 00:00:00 2001 From: lgbo-ustc Date: Tue, 26 Nov 2024 10:49:16 +0800 Subject: [PATCH] fix bugs --- .../extension/ConvertWindowToAggregate.scala | 39 +++++++---------- ...enClickHouseTPCHSaltNullParquetSuite.scala | 43 ++++++++++++++----- 2 files changed, 49 insertions(+), 33 deletions(-) diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/ConvertWindowToAggregate.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/ConvertWindowToAggregate.scala index 349a66e74a1c..aaf570411850 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/ConvertWindowToAggregate.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/ConvertWindowToAggregate.scala @@ -50,12 +50,6 @@ case class ConverRowNumbertWindowToAggregateRule(spark: SparkSession) partitionSpec, orderSpec, sort @ SortExecTransformer(_, _, _, _))) => - logDebug(s"xxx condition: $condition") - logDebug(s"xxx windowExpressions: $windowExpressions") - logDebug(s"xxx partitionSpec: $partitionSpec") - logDebug(s"xxx orderSpec: $orderSpec") - logDebug(s"xxx window output: ${window.output}") - logDebug(s"xxx child: ${sort.child.getClass}") if ( !isSupportedWindowFunction(windowExpressions) || !isTopKLimitFilter( condition, @@ -65,24 +59,23 @@ case class ConverRowNumbertWindowToAggregateRule(spark: SparkSession) s"xxx Not Supported case for converting window to aggregate. is topk limit: " + s"${isTopKLimitFilter(condition, windowExpressions(0))}. is supported window " + s"function: ${isSupportedWindowFunction(windowExpressions)}") - return filter - } - val limit = getLimit(condition.asInstanceOf[BinaryComparison]) - if (limit < 1 || limit > 100) { - return filter + filter + } else { + val limit = getLimit(condition.asInstanceOf[BinaryComparison]) + if (limit < 1 || limit > 100) { + filter + } else { + val groupLimit = CHAggregateGroupLimitExecTransformer( + partitionSpec, + orderSpec, + extractWindowFunction(windowExpressions(0)), + sort.child.output ++ Seq(windowExpressions(0).toAttribute), + limit, + sort.child + ) + groupLimit + } } - val groupLimit = CHAggregateGroupLimitExecTransformer( - partitionSpec, - orderSpec, - extractWindowFunction(windowExpressions(0)), - sort.child.output ++ Seq(windowExpressions(0).toAttribute), - limit, - sort.child - ) - logDebug(s"xxx windowGroupLimit: $groupLimit") - logDebug(s"xxx original window output: ${window.output}") - logDebug(s"xxx windowGroupLimit output: ${groupLimit.output}") - groupLimit } } diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala index dceb98bc6534..a61dedce15f7 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala @@ -3183,40 +3183,51 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends GlutenClickHouseTPCHAbstr |""".stripMargin) compareResultsAgainstVanillaSpark( """ - |select a, b, c, row_number() over (partition by a order by b desc nulls first, c nulls last) as r - |from test_win_top + |select * from( + |select a, b, c, + |row_number() over (partition by a order by b desc nulls first, c nulls last) as r + |from test_win_top) + |where r <= 1 |""".stripMargin, true, checkWindowGroupLimit ) compareResultsAgainstVanillaSpark( """ + |select * from( |select a, b, c, row_number() over (partition by a order by b desc, c nulls last) as r |from test_win_top + |)where r <= 1 |""".stripMargin, true, checkWindowGroupLimit ) compareResultsAgainstVanillaSpark( """ + |select * from( |select a, b, c, row_number() over (partition by a order by b asc nulls first, c) as r - |from test_win_top + |from test_win_top) + |where r <= 1 |""".stripMargin, true, checkWindowGroupLimit ) compareResultsAgainstVanillaSpark( """ + |select * from( |select a, b, c, row_number() over (partition by a order by b asc nulls last) as r - |from test_win_top + |from test_win_top) + |where r <= 1 |""".stripMargin, true, checkWindowGroupLimit ) compareResultsAgainstVanillaSpark( """ + |select * from( |select a, b, c, row_number() over (partition by a order by b , c) as r - |from test_win_top + |from test_win_top) + |where r <= 1 |""".stripMargin, true, checkWindowGroupLimit @@ -3238,6 +3249,7 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends GlutenClickHouseTPCHAbstr } assert(expands.size == 1) } + spark.sql("drop table if exists test_win_top") spark.sql("create table test_win_top (a string, b int, c int) using parquet") spark.sql(""" |insert into test_win_top values @@ -3247,40 +3259,51 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends GlutenClickHouseTPCHAbstr |""".stripMargin) compareResultsAgainstVanillaSpark( """ - |select a, b, c, row_number() over (partition by a order by b desc nulls first, c nulls last) as r + |select * from( + |select a, b, c, + |row_number() over (partition by a order by b desc nulls first, c nulls last) as r |from test_win_top + |)where r <= 1 |""".stripMargin, true, checkWindowGroupLimit ) compareResultsAgainstVanillaSpark( """ + |select * from( |select a, b, c, row_number() over (partition by a order by b desc, c nulls last) as r - |from test_win_top + |from test_win_top) + |where r <= 1 |""".stripMargin, true, checkWindowGroupLimit ) compareResultsAgainstVanillaSpark( """ + | select * from( |select a, b, c, row_number() over (partition by a order by b asc nulls first, c) as r - |from test_win_top + |from test_win_top) + |where r <= 1 |""".stripMargin, true, checkWindowGroupLimit ) compareResultsAgainstVanillaSpark( """ + |select * from( |select a, b, c, row_number() over (partition by a order by b asc nulls last) as r - |from test_win_top + |from test_win_top) + |where r <= 1 |""".stripMargin, true, checkWindowGroupLimit ) compareResultsAgainstVanillaSpark( """ + |select * from( |select a, b, c, row_number() over (partition by a order by b , c) as r - |from test_win_top + |from test_win_top) + |where r <= 1 |""".stripMargin, true, checkWindowGroupLimit