diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/ConvertWindowToAggregate.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/ConvertWindowToAggregate.scala index aaf570411850..ad2b22ba6cc6 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/ConvertWindowToAggregate.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/ConvertWindowToAggregate.scala @@ -24,12 +24,9 @@ import org.apache.gluten.expression.WindowFunctionsBuilder import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions._ -// import org.apache.spark.sql.catalyst.expressions.aggregate._ -// import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.execution._ +// import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.SparkPlan -// import org.apache.spark.sql.execution.window.Final import org.apache.spark.sql.types._ // When to find the first rows of partitions by window function, we can convert it to aggregate @@ -103,11 +100,9 @@ case class ConverRowNumbertWindowToAggregateRule(spark: SparkSession) def isSupportedWindowFunction(windowExpressions: Seq[NamedExpression]): Boolean = { if (windowExpressions.length != 1) { - logDebug(s"xxx windowExpressions length: ${windowExpressions.length}") return false } val windowFunction = extractWindowFunction(windowExpressions(0)) - logDebug(s"xxx windowFunction: $windowFunction") windowFunction match { case _: RowNumber => true case _ => false diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala index a61dedce15f7..48ea7a261a1f 100644 --- 
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala @@ -3168,9 +3168,7 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends GlutenClickHouseTPCHAbstr (runtimeConfigPrefix + "window.aggregate_topk_high_cardinality_threshold", "2.0")) { def checkWindowGroupLimit(df: DataFrame): Unit = { val expands = collectWithSubqueries(df.queryExecution.executedPlan) { - case e: ExpandExecTransformer - if (e.child.isInstanceOf[CHAggregateGroupLimitExecTransformer]) => - e + case e: CHAggregateGroupLimitExecTransformer => e } assert(expands.size == 1) } @@ -3243,9 +3241,7 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends GlutenClickHouseTPCHAbstr (runtimeConfigPrefix + "window.aggregate_topk_high_cardinality_threshold", "0.0")) { def checkWindowGroupLimit(df: DataFrame): Unit = { val expands = collectWithSubqueries(df.queryExecution.executedPlan) { - case e: ExpandExecTransformer - if (e.child.isInstanceOf[CHAggregateGroupLimitExecTransformer]) => - e + case e: CHAggregateGroupLimitExecTransformer => e } assert(expands.size == 1) } diff --git a/cpp-ch/local-engine/Common/SortUtils.cpp b/cpp-ch/local-engine/Common/SortUtils.cpp new file mode 100644 index 000000000000..1b18cc4bfaf5 --- /dev/null +++ b/cpp-ch/local-engine/Common/SortUtils.cpp @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "SortUtils.h" +#include +#include +#include +#include +#include + +namespace DB::ErrorCodes +{ +extern const int BAD_ARGUMENTS; +extern const int LOGICAL_ERROR; +} + +namespace local_engine +{ +DB::SortDescription parseSortFields(const DB::Block & header, const google::protobuf::RepeatedPtrField & expressions) +{ + DB::SortDescription description; + for (const auto & expr : expressions) + if (expr.has_selection()) + { + auto pos = expr.selection().direct_reference().struct_field().field(); + const auto & col_name = header.getByPosition(pos).name; + description.push_back(DB::SortColumnDescription(col_name, 1, 1)); + } + else if (expr.has_literal()) + continue; + else + throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Unknow expression as sort field: {}", expr.DebugString()); + return description; +} + +DB::SortDescription parseSortFields(const DB::Block & header, const google::protobuf::RepeatedPtrField & sort_fields) +{ + static std::map> direction_map = {{1, {1, -1}}, {2, {1, 1}}, {3, {-1, 1}}, {4, {-1, -1}}}; + + DB::SortDescription sort_descr; + for (int i = 0, sz = sort_fields.size(); i < sz; ++i) + { + const auto & sort_field = sort_fields[i]; + /// There is no meaning to sort a const column. 
+ if (sort_field.expr().has_literal()) + continue; + + if (!sort_field.expr().has_selection() || !sort_field.expr().selection().has_direct_reference() + || !sort_field.expr().selection().direct_reference().has_struct_field()) + { + throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Unsupport sort field"); + } + auto field_pos = sort_field.expr().selection().direct_reference().struct_field().field(); + + auto direction_iter = direction_map.find(sort_field.direction()); + if (direction_iter == direction_map.end()) + throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Unsuppor sort direction: {}", sort_field.direction()); + const auto & col_name = header.getByPosition(field_pos).name; + sort_descr.emplace_back(col_name, direction_iter->second.first, direction_iter->second.second); + } + return sort_descr; +} + +std::string +buildSQLLikeSortDescription(const DB::Block & header, const google::protobuf::RepeatedPtrField & sort_fields) +{ + static const std::unordered_map order_directions + = {{1, " asc nulls first"}, {2, " asc nulls last"}, {3, " desc nulls first"}, {4, " desc nulls last"}}; + size_t n = 0; + DB::WriteBufferFromOwnString ostr; + for (const auto & sort_field : sort_fields) + { + auto it = order_directions.find(sort_field.direction()); + if (it == order_directions.end()) + throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Unknow sort direction: {}", sort_field.direction()); + if (!sort_field.expr().has_selection()) + { + throw DB::Exception( + DB::ErrorCodes::BAD_ARGUMENTS, "Sort field must be a column reference. but got {}", sort_field.DebugString()); + } + auto ref = sort_field.expr().selection().direct_reference().struct_field().field(); + const auto & col_name = header.getByPosition(ref).name; + if (n) + ostr << String(","); + // the col_name may contain '#', which may cause CH to fail to parse it. 
+ ostr << "`" << col_name << "`" << it->second; + n += 1; + } + LOG_DEBUG(getLogger("AggregateGroupLimitRelParser"), "Order by clasue: {}", ostr.str()); + return ostr.str(); +} +} diff --git a/cpp-ch/local-engine/Common/SortUtils.h b/cpp-ch/local-engine/Common/SortUtils.h new file mode 100644 index 000000000000..c460fa758b6d --- /dev/null +++ b/cpp-ch/local-engine/Common/SortUtils.h @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once +#include +#include +#include +#include + +namespace local_engine +{ +// convert expressions into sort description +DB::SortDescription +parseSortFields(const DB::Block & header, const google::protobuf::RepeatedPtrField & expressions); +DB::SortDescription parseSortFields(const DB::Block & header, const google::protobuf::RepeatedPtrField & sort_fields); + +std::string +buildSQLLikeSortDescription(const DB::Block & header, const google::protobuf::RepeatedPtrField & sort_fields); +} diff --git a/cpp-ch/local-engine/Operator/BranchStep.cpp b/cpp-ch/local-engine/Operator/BranchStep.cpp index faf05ebeaeb3..5e379ae9d4dc 100644 --- a/cpp-ch/local-engine/Operator/BranchStep.cpp +++ b/cpp-ch/local-engine/Operator/BranchStep.cpp @@ -35,68 +35,6 @@ namespace local_engine { -class BranchOutputTransform : public DB::IProcessor -{ -public: - using Status = DB::IProcessor::Status; - BranchOutputTransform(const DB::Block & header_) : DB::IProcessor({header_}, {header_}) { } - ~BranchOutputTransform() override = default; - - String getName() const override { return "BranchOutputTransform"; } - - Status prepare() override; - void work() override; - -private: - bool has_output = false; - DB::Chunk output_chunk; - bool has_input = false; - DB::Chunk input_chunk; -}; - -BranchOutputTransform::Status BranchOutputTransform::prepare() -{ - auto & output = outputs.front(); - auto & input = inputs.front(); - if (output.isFinished()) - { - input.close(); - return Status::Finished; - } - if (has_output) - { - if (output.canPush()) - { - output.push(std::move(output_chunk)); - has_output = false; - } - return Status::PortFull; - } - if (has_input) - return Status::Ready; - if (input.isFinished()) - { - output.finish(); - return Status::Finished; - } - input.setNeeded(); - if (!input.hasData()) - return Status::NeedData; - input_chunk = input.pull(true); - has_input = true; - return Status::Ready; -} - -void BranchOutputTransform::work() -{ - if (has_input) - { - 
output_chunk = std::move(input_chunk); - has_output = true; - has_input = false; - } -} - class BranchHookSource : public DB::IProcessor { public: @@ -240,13 +178,6 @@ void StaticBranchStep::transformPipeline(DB::QueryPipelineBuilder & pipeline, co auto branch_transform = std::make_shared(header, max_sample_rows, branches, selector); DB::connect(*output, branch_transform->getInputs().front()); new_processors.push_back(branch_transform); - - for (auto & branch_output : branch_transform->getOutputs()) - { - auto branch_processor = std::make_shared(header); - DB::connect(branch_output, branch_processor->getInputs().front()); - new_processors.push_back(branch_processor); - } } return new_processors; }; @@ -278,6 +209,14 @@ void UniteBranchesStep::transformPipeline(DB::QueryPipelineBuilder & pipeline, c { DB::Processors new_processors; size_t branch_index = 0; + if (child_outputs.size() != branch_plans.size()) + { + throw DB::Exception( + DB::ErrorCodes::LOGICAL_ERROR, + "Output port's size({}) is not equal to branches size({})", + child_outputs.size(), + branch_plans.size()); + } for (auto output : child_outputs) { auto & branch_plan = branch_plans[branch_index]; diff --git a/cpp-ch/local-engine/Operator/BranchStep.h b/cpp-ch/local-engine/Operator/BranchStep.h index 17056aa9b5a6..ddbd4c6fbb70 100644 --- a/cpp-ch/local-engine/Operator/BranchStep.h +++ b/cpp-ch/local-engine/Operator/BranchStep.h @@ -62,6 +62,8 @@ class StaticBranchStep : public DB::ITransformingStep BranchSelector selector; }; + +// It should be better to build execution branches on QueryPlan. 
class UniteBranchesStep : public DB::ITransformingStep { public: diff --git a/cpp-ch/local-engine/Operator/WindowGroupLimitStep.cpp b/cpp-ch/local-engine/Operator/WindowGroupLimitStep.cpp index 81cc44cf9c08..d2264e24dc13 100644 --- a/cpp-ch/local-engine/Operator/WindowGroupLimitStep.cpp +++ b/cpp-ch/local-engine/Operator/WindowGroupLimitStep.cpp @@ -17,9 +17,6 @@ #include "WindowGroupLimitStep.h" -#include -#include -#include #include #include #include @@ -40,36 +37,18 @@ enum class WindowGroupLimitFunction DenseRank }; - -static DB::Block buildOutputHeader(const DB::Block & input_header, bool need_output_rank_values) -{ - if (!need_output_rank_values) - return input_header; - DB::Block output_header = input_header; - auto type = std::make_shared(); - auto col = type->createColumn(); - output_header.insert(DB::ColumnWithTypeAndName(std::move(col), type, "rank_value")); - return output_header; -} - template class WindowGroupLimitTransform : public DB::IProcessor { public: using Status = DB::IProcessor::Status; explicit WindowGroupLimitTransform( - const DB::Block & header_, - const std::vector & partition_columns_, - const std::vector & sort_columns_, - size_t limit_, - bool need_output_rank_values_ = false) - : DB::IProcessor({header_}, {buildOutputHeader(header_, need_output_rank_values_)}) + const DB::Block & header_, const std::vector & partition_columns_, const std::vector & sort_columns_, size_t limit_) + : DB::IProcessor({header_}, {header_}) , header(header_) , partition_columns(partition_columns_) , sort_columns(sort_columns_) , limit(limit_) - , need_output_rank_values(need_output_rank_values_) - { } ~WindowGroupLimitTransform() override = default; @@ -136,11 +115,6 @@ class WindowGroupLimitTransform : public DB::IProcessor if (!output_columns.empty() && output_columns[0]->size() > 0) { auto rows = output_columns[0]->size(); - if (rank_value_column) - { - output_columns.push_back(std::move(rank_value_column)); - rank_value_column.reset(); - } output_chunk 
= DB::Chunk(std::move(output_columns), rows); output_columns.clear(); has_output = true; @@ -156,13 +130,11 @@ class WindowGroupLimitTransform : public DB::IProcessor std::vector sort_columns; // Limitations for each partition. size_t limit = 0; - bool need_output_rank_values; bool has_input = false; DB::Chunk input_chunk; bool has_output = false; DB::MutableColumns output_columns; - DB::MutableColumnPtr rank_value_column = nullptr; DB::Chunk output_chunk; // We don't have window frame here. in fact all of frame are (unbounded preceding, current row] @@ -175,13 +147,6 @@ class WindowGroupLimitTransform : public DB::IProcessor DB::Columns partition_start_row_columns; DB::Columns peer_group_start_row_columns; - - void tryCreateRankValueColumn() - { - if (!rank_value_column) - rank_value_column = DB::DataTypeInt32().createColumn(); - } - size_t advanceNextPartition(const DB::Chunk & chunk, size_t start_offset) { if (partition_start_row_columns.empty()) @@ -265,12 +230,6 @@ class WindowGroupLimitTransform : public DB::IProcessor size_t rows = end_offset - start_offset; size_t limit_remained = limit - current_row_rank_value + 1; rows = rows > limit_remained ? 
limit_remained : rows; - if (need_output_rank_values) - { - tryCreateRankValueColumn(); - for (Int32 i = 0; i < static_cast(rows); ++i) - typeid_cast *>(rank_value_column.get())->insertValue(current_row_rank_value + i); - } insertResultValue(chunk, start_offset, rows); current_row_rank_value += rows; @@ -282,11 +241,6 @@ class WindowGroupLimitTransform : public DB::IProcessor { auto next_peer_group_start_offset = advanceNextPeerGroup(chunk, peer_group_start_offset, end_offset); size_t group_rows = next_peer_group_start_offset - peer_group_start_offset; - if (need_output_rank_values) - { - tryCreateRankValueColumn(); - rank_value_column->insertMany(current_row_rank_value, group_rows); - } insertResultValue(chunk, peer_group_start_offset, group_rows); try_end_peer_group(peer_group_start_offset, next_peer_group_start_offset, end_offset, chunk_rows); peer_group_start_offset = next_peer_group_start_offset; @@ -335,14 +289,12 @@ WindowGroupLimitStep::WindowGroupLimitStep( const String & function_name_, const std::vector & partition_columns_, const std::vector & sort_columns_, - size_t limit_, - bool need_output_rank_values_) - : DB::ITransformingStep(input_header_, buildOutputHeader(input_header_, need_output_rank_values_), getTraits()) + size_t limit_) + : DB::ITransformingStep(input_header_, input_header_, getTraits()) , function_name(function_name_) , partition_columns(partition_columns_) , sort_columns(sort_columns_) , limit(limit_) - , need_output_rank_values(need_output_rank_values_) { } @@ -366,7 +318,7 @@ void WindowGroupLimitStep::transformPipeline(DB::QueryPipelineBuilder & pipeline [&](const DB::Block & header) { return std::make_shared>( - header, partition_columns, sort_columns, limit, need_output_rank_values); + header, partition_columns, sort_columns, limit); }); } else if (function_name == "rank") diff --git a/cpp-ch/local-engine/Operator/WindowGroupLimitStep.h b/cpp-ch/local-engine/Operator/WindowGroupLimitStep.h index e7658592a682..55e3eaeb72d3 100644 
--- a/cpp-ch/local-engine/Operator/WindowGroupLimitStep.h +++ b/cpp-ch/local-engine/Operator/WindowGroupLimitStep.h @@ -31,8 +31,7 @@ class WindowGroupLimitStep : public DB::ITransformingStep const String & function_name_, const std::vector & partition_columns_, const std::vector & sort_columns_, - size_t limit_, - bool need_output_rank_values_ = false); + size_t limit_); ~WindowGroupLimitStep() override = default; String getName() const override { return "WindowGroupLimitStep"; } @@ -47,7 +46,6 @@ class WindowGroupLimitStep : public DB::ITransformingStep std::vector partition_columns; std::vector sort_columns; size_t limit; - bool need_output_rank_values; }; } diff --git a/cpp-ch/local-engine/Parser/RelParsers/GroupLimitRelParser.cpp b/cpp-ch/local-engine/Parser/RelParsers/GroupLimitRelParser.cpp index 1f9b8f759be8..06f68e8ae218 100644 --- a/cpp-ch/local-engine/Parser/RelParsers/GroupLimitRelParser.cpp +++ b/cpp-ch/local-engine/Parser/RelParsers/GroupLimitRelParser.cpp @@ -29,6 +29,7 @@ #include #include #include +#include #include #include #include @@ -38,10 +39,12 @@ #include #include #include +#include #include #include #include #include +#include #include #include #include @@ -51,6 +54,7 @@ #include #include #include +#include #include namespace DB::ErrorCodes @@ -184,6 +188,7 @@ size_t selectBranchOnPartitionKeysCardinality( DB::QueryPlanPtr AggregateGroupLimitRelParser::parse( DB::QueryPlanPtr current_plan_, const substrait::Rel & rel, std::list & rel_stack_) { + // calculate window's topk by aggregation. // 1. add a pre-projecttion. Make two tuple arguments for the aggregation function. One is the required columns for the output, the other // is the required columns for sorting. // 2. Collect the sorting directions for each sorting field, Let them as the aggregation function's parameters. 
@@ -206,6 +211,9 @@ DB::QueryPlanPtr AggregateGroupLimitRelParser::parse( auto win_config = WindowConfig::loadFromContext(getContext()); auto high_card_threshold = win_config.aggregate_topk_high_cardinality_threshold; + // Aggregation doesn't perform well on high cardinality keys. We make two execution pathes here. + // - if the partition keys are low cardinality, run it by aggregation + // - if the partition keys are high cardinality, run it by window. auto partition_fields = parsePartitionFields(win_rel_def->partition_expressions()); auto branch_in_header = current_plan->getCurrentHeader(); auto branch_step = std::make_unique( @@ -226,8 +234,6 @@ DB::QueryPlanPtr AggregateGroupLimitRelParser::parse( postProjectionForExplodingArrays(*aggregation_plan); LOG_DEBUG(getLogger("AggregateGroupLimitRelParser"), "Aggregate topk plan:\n{}", PlanUtil::explainPlan(*aggregation_plan)); - - // aggregation doesn't performs well on high cardinality keys. use sort + window to get the topks. auto window_plan = BranchStepHelper::createSubPlan(branch_in_header, 1); addSortStep(*window_plan); addWindowLimitStep(*window_plan); @@ -259,7 +265,7 @@ String AggregateGroupLimitRelParser::getAggregateFunctionName(const String & win throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Unsupported window function: {}", window_function_name); } -// Build two tuple columns as the aggregate function's arguments +// Build one tuple column as the aggregate function's arguments void AggregateGroupLimitRelParser::prePrejectionForAggregateArguments(DB::QueryPlan & plan) { auto projection_actions = std::make_shared(input_header.getColumnsWithTypeAndName()); @@ -322,36 +328,6 @@ void AggregateGroupLimitRelParser::prePrejectionForAggregateArguments(DB::QueryP plan.addStep(std::move(expression_step)); } - -String AggregateGroupLimitRelParser::parseSortDirections(const google::protobuf::RepeatedPtrField & sort_fields) -{ - DB::Array directions; - static const std::unordered_map order_directions - = {{1, " 
asc nulls first"}, {2, " asc nulls last"}, {3, " desc nulls first"}, {4, " desc nulls last"}}; - size_t n = 0; - DB::WriteBufferFromOwnString ostr; - for (const auto & sort_field : sort_fields) - { - auto it = order_directions.find(sort_field.direction()); - if (it == order_directions.end()) - throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Unknow sort direction: {}", sort_field.direction()); - if (!sort_field.expr().has_selection()) - { - throw DB::Exception( - DB::ErrorCodes::BAD_ARGUMENTS, "Sort field must be a column reference. but got {}", sort_field.DebugString()); - } - auto ref = sort_field.expr().selection().direct_reference().struct_field().field(); - const auto & col_name = input_header.getByPosition(ref).name; - if (n) - ostr << String(","); - // the col_name may contain '#' which can may ch fail to parse. - ostr << "`" << col_name << "`" << it->second; - n += 1; - } - LOG_DEBUG(getLogger("AggregateGroupLimitRelParser"), "Order by clasue: {}", ostr.str()); - return ostr.str(); -} - DB::AggregateDescription AggregateGroupLimitRelParser::buildAggregateDescription(DB::QueryPlan & plan) { DB::AggregateDescription agg_desc; @@ -359,7 +335,7 @@ DB::AggregateDescription AggregateGroupLimitRelParser::buildAggregateDescription agg_desc.argument_names = {aggregate_tuple_column_name}; DB::Array parameters; parameters.push_back(static_cast(limit)); - auto sort_directions = parseSortDirections(win_rel_def->sorts()); + auto sort_directions = buildSQLLikeSortDescription(input_header, win_rel_def->sorts()); parameters.push_back(sort_directions); auto header = plan.getCurrentHeader(); @@ -457,6 +433,39 @@ void AggregateGroupLimitRelParser::addSortStep(DB::QueryPlan & plan) plan.addStep(std::move(sorting_step)); } +static DB::WindowFrame buildWindowFrame(const std::string & ch_function_name) +{ + DB::WindowFrame frame; + // default window frame is [unbounded preceding, current row] + if (ch_function_name == "row_number") + { + frame.type = 
DB::WindowFrame::FrameType::ROWS; + frame.begin_type = DB::WindowFrame::BoundaryType::Offset; + frame.begin_offset = 1; + } + else + throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Unknow window function: {}", ch_function_name); + return frame; +} + +static DB::WindowFunctionDescription buildWindowFunctionDescription(const std::string & ch_function_name) +{ + DB::WindowFunctionDescription description; + if (ch_function_name == "row_number") + { + description.column_name = ch_function_name; + description.function_node = nullptr; + DB::AggregateFunctionProperties agg_props; + auto agg_func = RelParser::getAggregateFunction(ch_function_name, {}, agg_props, {}); + description.aggregate_function = agg_func; + } + else + throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Unknow window function: {}", ch_function_name); + return description; +} + + +// TODO: WindowGroupLimitStep has bad performance, need to improve it. So we still use window + filter here. void AggregateGroupLimitRelParser::addWindowLimitStep(DB::QueryPlan & plan) { google::protobuf::StringValue optimize_info_str; @@ -464,14 +473,35 @@ void AggregateGroupLimitRelParser::addWindowLimitStep(DB::QueryPlan & plan) auto optimization_info = WindowGroupOptimizationInfo::parse(optimize_info_str.value()); auto window_function_name = optimization_info.window_function; - auto partition_fields = parsePartitionFields(win_rel_def->partition_expressions()); - auto sort_fields = parseSortFields(win_rel_def->sorts()); - size_t limit = static_cast(win_rel_def->limit()); - - auto window_group_limit_step - = std::make_unique(plan.getCurrentHeader(), window_function_name, partition_fields, sort_fields, limit, true); - window_group_limit_step->setStepDescription("Window group limit"); - plan.addStep(std::move(window_group_limit_step)); + auto in_header = plan.getCurrentHeader(); + DB::WindowDescription win_descr; + win_descr.frame = buildWindowFrame(window_function_name); + win_descr.partition_by = 
parseSortFields(in_header, win_rel_def->partition_expressions()); + win_descr.order_by = parseSortFields(in_header, win_rel_def->sorts()); + win_descr.full_sort_description = win_descr.partition_by; + win_descr.full_sort_description.insert(win_descr.full_sort_description.end(), win_descr.order_by.begin(), win_descr.order_by.end()); + DB::WriteBufferFromOwnString ss; + ss << "partition by " << DB::dumpSortDescription(win_descr.partition_by); + ss << " order by " << DB::dumpSortDescription(win_descr.order_by); + ss << " " << win_descr.frame.toString(); + win_descr.window_name = ss.str(); + + auto win_func_description = buildWindowFunctionDescription(window_function_name); + win_descr.window_functions.push_back(win_func_description); + + auto win_step = std::make_unique(in_header, win_descr, win_descr.window_functions, false); + win_step->setStepDescription("Window (" + win_descr.window_name + ")"); + plan.addStep(std::move(win_step)); + + auto win_result_header = plan.getCurrentHeader(); + DB::ActionsDAG limit_actions_dag(win_result_header.getColumnsWithTypeAndName()); + const auto * rank_value_node = limit_actions_dag.getInputs().back(); + const auto * limit_value_node = expression_parser->addConstColumn(limit_actions_dag, std::make_shared(), limit); + const auto * cmp_node = expression_parser->toFunctionNode(limit_actions_dag, "lessOrEquals", {rank_value_node, limit_value_node}); + auto cmp_column_name = cmp_node->result_name; + limit_actions_dag.addOrReplaceInOutputs(*cmp_node); + auto filter_step = std::make_unique(win_result_header, std::move(limit_actions_dag), cmp_column_name, true); + plan.addStep(std::move(filter_step)); } void registerWindowGroupLimitRelParser(RelParserFactory & factory)