From fd56ea416460d0eb5cca0f42aa9538cc460c669a Mon Sep 17 00:00:00 2001 From: lgbo-ustc Date: Tue, 3 Sep 2024 13:57:00 +0800 Subject: [PATCH 1/3] support WindowGroupLimit --- .../backendsapi/clickhouse/CHBackend.scala | 2 + cpp-ch/local-engine/Parser/RelParser.cpp | 14 +- .../Parser/SerializedPlanParser.cpp | 29 ++- .../Parser/WindowGroupLimitRelParser.cpp | 176 ++++++++++++++++++ .../Parser/WindowGroupLimitRelParser.h | 57 ++++++ 5 files changed, 256 insertions(+), 22 deletions(-) create mode 100644 cpp-ch/local-engine/Parser/WindowGroupLimitRelParser.cpp create mode 100644 cpp-ch/local-engine/Parser/WindowGroupLimitRelParser.h diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala index 69ea899c42a5..45aee4322611 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala @@ -418,4 +418,6 @@ object CHBackendSettings extends BackendSettingsApi with Logging { } } } + + override def supportWindowGroupLimitExec(rankLikeFunction: Expression): Boolean = true } diff --git a/cpp-ch/local-engine/Parser/RelParser.cpp b/cpp-ch/local-engine/Parser/RelParser.cpp index f651146a391d..a7f6d0586455 100644 --- a/cpp-ch/local-engine/Parser/RelParser.cpp +++ b/cpp-ch/local-engine/Parser/RelParser.cpp @@ -30,8 +30,8 @@ namespace DB { namespace ErrorCodes { - extern const int BAD_ARGUMENTS; - extern const int LOGICAL_ERROR; +extern const int BAD_ARGUMENTS; +extern const int LOGICAL_ERROR; } } @@ -89,14 +89,15 @@ DB::QueryPlanPtr RelParser::parseOp(const substrait::Rel & rel, std::list RelParser::parseFormattedRelAdvancedOptimization(const substrait::extensions::AdvancedExtension &advanced_extension) +std::map +RelParser::parseFormattedRelAdvancedOptimization(const substrait::extensions::AdvancedExtension & advanced_extension) { std::map configs; if (advanced_extension.has_optimization()) { google::protobuf::StringValue msg; advanced_extension.optimization().UnpackTo(&msg); - Poco::StringTokenizer kvs( msg.value(), "\n"); + Poco::StringTokenizer kvs(msg.value(), "\n"); for (auto & kv : kvs) { if (kv.empty()) @@ -114,7 +115,8 @@ std::map RelParser::parseFormattedRelAdvancedOptimizat return configs; } -std::string RelParser::getStringConfig(const std::map & configs, const std::string & key, const std::string & default_value) +std::string +RelParser::getStringConfig(const std::map & configs, const std::string & key, const std::string & default_value) { auto it = configs.find(key); if (it == configs.end()) @@ -150,6 +152,7 @@ RelParserFactory::RelParserBuilder RelParserFactory::getBuilder(UInt32 k) } void registerWindowRelParser(RelParserFactory & factory); +void registerWindowGroupLimitRelParser(RelParserFactory & factory); void registerSortRelParser(RelParserFactory & factory); void registerExpandRelParser(RelParserFactory & factory); void registerAggregateParser(RelParserFactory & factory); @@ -162,6 +165,7 @@ void registerRelParsers() { auto & factory = RelParserFactory::instance(); registerWindowRelParser(factory); + registerWindowGroupLimitRelParser(factory); registerSortRelParser(factory); registerExpandRelParser(factory); registerAggregateParser(factory); diff --git a/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp b/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp index 8efbd97d240d..fc0650993766 100644 --- a/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp +++ b/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp @@ -389,9 +389,9 @@ void adjustOutput(const DB::QueryPlanPtr & query_plan, const substrait::PlanRel } if (need_final_project) { - ActionsDAG final_project - = ActionsDAG::makeConvertingActions(original_cols, final_cols, ActionsDAG::MatchColumnsMode::Position); - QueryPlanStepPtr final_project_step = std::make_unique(query_plan->getCurrentDataStream(), std::move(final_project)); + ActionsDAG final_project = ActionsDAG::makeConvertingActions(original_cols, final_cols, ActionsDAG::MatchColumnsMode::Position); + QueryPlanStepPtr final_project_step + = std::make_unique(query_plan->getCurrentDataStream(), std::move(final_project)); final_project_step->setStepDescription("Project for output schema"); query_plan->addStep(std::move(final_project_step)); } @@ -499,6 +499,7 @@ QueryPlanPtr SerializedPlanParser::parseOp(const substrait::Rel & rel, std::list case substrait::Rel::RelTypeCase::kWindow: case substrait::Rel::RelTypeCase::kJoin: case substrait::Rel::RelTypeCase::kCross: + case substrait::Rel::RelTypeCase::kWindowGroupLimit: case substrait::Rel::RelTypeCase::kExpand: { auto op_parser = RelParserFactory::instance().getBuilder(rel.rel_type_case())(this); query_plan = op_parser->parseOp(rel, rel_stack); @@ -601,7 +602,7 @@ void SerializedPlanParser::parseArrayJoinArguments( } ActionsDAG::NodeRawConstPtrs SerializedPlanParser::parseArrayJoinWithDAG( - const substrait::Expression & rel, std::vector & result_names, ActionsDAG& actions_dag, bool keep_result, bool position) + const substrait::Expression & rel, std::vector & result_names, ActionsDAG & actions_dag, bool keep_result, bool position) { if (!rel.has_scalar_function()) throw Exception(ErrorCodes::BAD_ARGUMENTS, "The root of expression should be a scalar function:\n {}", rel.DebugString()); @@ -718,7 +719,7 @@ ActionsDAG::NodeRawConstPtrs SerializedPlanParser::parseArrayJoinWithDAG( } const ActionsDAG::Node * SerializedPlanParser::parseFunctionWithDAG( - const substrait::Expression & rel, std::string & result_name, ActionsDAG& actions_dag, bool keep_result) + const substrait::Expression & rel, std::string & result_name, ActionsDAG & actions_dag, bool keep_result) { if (!rel.has_scalar_function()) throw Exception(ErrorCodes::BAD_ARGUMENTS, "the root of expression should be a scalar function:\n {}", rel.DebugString()); @@ -780,9 +781,8 @@ bool SerializedPlanParser::isFunction(substrait::Expression_ScalarFunction rel, } void SerializedPlanParser::parseFunctionOrExpression( - const substrait::Expression & rel, std::string & result_name, ActionsDAG& actions_dag, bool keep_result) + const substrait::Expression & rel, std::string & result_name, ActionsDAG & actions_dag, bool keep_result) { - if (rel.has_scalar_function()) parseFunctionWithDAG(rel, result_name, actions_dag, keep_result); else @@ -793,11 +793,7 @@ void SerializedPlanParser::parseFunctionOrExpression( } void SerializedPlanParser::parseJsonTuple( - const substrait::Expression & rel, - std::vector & result_names, - ActionsDAG& actions_dag, - bool keep_result, - bool) + const substrait::Expression & rel, std::vector & result_names, ActionsDAG & actions_dag, bool keep_result, bool) { const auto & scalar_function = rel.scalar_function(); auto function_signature = function_mapping.at(std::to_string(rel.scalar_function().function_reference())); @@ -856,7 +852,7 @@ void SerializedPlanParser::parseJsonTuple( } const ActionsDAG::Node * -SerializedPlanParser::toFunctionNode(ActionsDAG& actions_dag, const String & function, const ActionsDAG::NodeRawConstPtrs & args) +SerializedPlanParser::toFunctionNode(ActionsDAG & actions_dag, const String & function, const ActionsDAG::NodeRawConstPtrs & args) { auto function_builder = FunctionFactory::instance().get(function, context); std::string args_name = join(args, ','); @@ -1068,7 +1064,7 @@ std::pair SerializedPlanParser::parseLiteral(const substrait return std::make_pair(std::move(type), std::move(field)); } -const ActionsDAG::Node * SerializedPlanParser::parseExpression(ActionsDAG& actions_dag, const substrait::Expression & rel) +const ActionsDAG::Node * SerializedPlanParser::parseExpression(ActionsDAG & actions_dag, const substrait::Expression & rel) { switch (rel.rex_type_case()) { @@ -1533,8 +1529,7 @@ ASTPtr ASTParser::parseArgumentToAST(const Names & names, const substrait::Expre } } -void SerializedPlanParser::removeNullableForRequiredColumns( - const std::set & require_columns, ActionsDAG & actions_dag) const +void SerializedPlanParser::removeNullableForRequiredColumns(const std::set & require_columns, ActionsDAG & actions_dag) const { for (const auto & item : require_columns) { @@ -1549,7 +1544,7 @@ void SerializedPlanParser::removeNullableForRequiredColumns( } void SerializedPlanParser::wrapNullable( - const std::vector & columns, ActionsDAG& actions_dag, std::map & nullable_measure_names) + const std::vector & columns, ActionsDAG & actions_dag, std::map & nullable_measure_names) { for (const auto & item : columns) { diff --git a/cpp-ch/local-engine/Parser/WindowGroupLimitRelParser.cpp b/cpp-ch/local-engine/Parser/WindowGroupLimitRelParser.cpp new file mode 100644 index 000000000000..b1dd98282659 --- /dev/null +++ b/cpp-ch/local-engine/Parser/WindowGroupLimitRelParser.cpp @@ -0,0 +1,176 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include + +namespace DB::ErrorCodes +{ +extern const int BAD_ARGUMENTS; +} + +const static String FUNCTION_ROW_NUM = "row_number"; +const static String FUNCTION_RANK = "top_rank"; +const static String FUNCTION_DENSE_RANK = "top_dense_rank"; + +namespace local_engine +{ +WindowGroupLimitRelParser::WindowGroupLimitRelParser(SerializedPlanParser * plan_parser_) : RelParser(plan_parser_) +{ + LOG_ERROR(getLogger("WindowGroupLimitRelParser"), "xxx new parrser"); +} +DB::QueryPlanPtr +WindowGroupLimitRelParser::parse(DB::QueryPlanPtr current_plan_, const substrait::Rel & rel, std::list & rel_stack_) +{ + const auto win_rel_def = rel.windowgrouplimit(); + current_plan = std::move(current_plan_); + + DB::Block output_header = current_plan->getCurrentDataStream().header; + + window_function_name = FUNCTION_ROW_NUM; + LOG_ERROR(getLogger("WindowGroupLimitRelParser"), "xxx input header: {}", current_plan->getCurrentDataStream().header.dumpStructure()); + + /// Only one window function in one window group limit + auto win_desc = buildWindowDescription(win_rel_def); + + auto win_step = std::make_unique(current_plan->getCurrentDataStream(), win_desc, win_desc.window_functions, false); + win_step->setStepDescription("Window Group Limit " + win_desc.window_name); + steps.emplace_back(win_step.get()); + current_plan->addStep(std::move(win_step)); + + /// remove the window function result column which is not needed in later steps + DB::ActionsDAG post_project_actions_dag = DB::ActionsDAG::makeConvertingActions( + current_plan->getCurrentDataStream().header.getColumnsWithTypeAndName(), + output_header.getColumnsWithTypeAndName(), + DB::ActionsDAG::MatchColumnsMode::Name); + auto post_project_step + = std::make_unique(current_plan->getCurrentDataStream(), std::move(post_project_actions_dag)); + post_project_step->setStepDescription("Window group limit: drop window function result column"); + steps.emplace_back(post_project_step.get()); + current_plan->addStep(std::move(post_project_step)); + + LOG_ERROR(getLogger("WindowGroupLimitRelParser"), "xxx output header: {}", current_plan->getCurrentDataStream().header.dumpStructure()); + bool x = true; + if (x) + { + throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Invalide rel"); + } + return std::move(current_plan); +} + +DB::WindowFrame WindowGroupLimitRelParser::buildWindowFrame(const String & function_name) +{ + DB::WindowFrame frame; + if (function_name == FUNCTION_ROW_NUM) + { + frame.type = DB::WindowFrame::FrameType::ROWS; + frame.begin_type = DB::WindowFrame::BoundaryType::Offset; + frame.begin_offset = 1; + frame.begin_preceding = true; + frame.end_type = DB::WindowFrame::BoundaryType::Current; + frame.end_offset = 0; + frame.end_preceding = true; + } + else if (function_name == FUNCTION_RANK || function_name == FUNCTION_DENSE_RANK) + { + // rank and dense_rank can only work on range mode + frame.type = DB::WindowFrame::FrameType::RANGE; + frame.begin_type = DB::WindowFrame::BoundaryType::Unbounded; + frame.begin_offset = 0; + frame.begin_preceding = true; + frame.end_type = DB::WindowFrame::BoundaryType::Current; + frame.end_offset = 0; + frame.end_preceding = true; + } + else + { + throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Unknown function {} for window group limit", function_name); + } + + return frame; +} + +DB::WindowDescription WindowGroupLimitRelParser::buildWindowDescription(const substrait::WindowGroupLimitRel & win_rel_def) +{ + DB::WindowDescription win_desc; + win_desc.frame = buildWindowFrame(window_function_name); + win_desc.partition_by = parsePartitionBy(win_rel_def.partition_expressions()); + win_desc.order_by = SortRelParser::parseSortDescription(win_rel_def.sorts(), current_plan->getCurrentDataStream().header); + win_desc.full_sort_description = win_desc.partition_by; + win_desc.full_sort_description.insert(win_desc.full_sort_description.end(), win_desc.order_by.begin(), win_desc.order_by.end()); + + DB::WriteBufferFromOwnString ss; + ss << "partition by " << DB::dumpSortDescription(win_desc.partition_by); + ss << "order by " << DB::dumpSortDescription(win_desc.order_by); + ss << win_desc.frame.toString(); + win_desc.window_name = ss.str(); + + win_desc.window_functions.emplace_back(buildWindowFunctionDescription(window_function_name)); + + return win_desc; +} + +DB::SortDescription +WindowGroupLimitRelParser::parsePartitionBy(const google::protobuf::RepeatedPtrField & expressions) +{ + DB::Block header = current_plan->getCurrentDataStream().header; + DB::SortDescription sort_desc; + for (const auto & expr : expressions) + { + if (expr.has_selection()) + { + auto pos = expr.selection().direct_reference().struct_field().field(); + auto col_name = header.getByPosition(pos).name; + sort_desc.push_back(DB::SortColumnDescription(col_name, 1, 1)); + } + else if (expr.has_literal()) + { + continue; + } + else + { + throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Unknow partition argument: {}", expr.DebugString()); + } + } + return sort_desc; +} + +DB::WindowFunctionDescription WindowGroupLimitRelParser::buildWindowFunctionDescription(const String & function_name) +{ + DB::WindowFunctionDescription desc; + desc.column_name = function_name; + desc.function_node = nullptr; + DB::AggregateFunctionProperties func_properties; + DB::Names func_args; + DB::DataTypes func_args_types; + DB::Array func_params; + auto func_ptr = RelParser::getAggregateFunction(function_name, func_args_types, func_properties, func_params); + desc.argument_names = func_args; + desc.argument_types = func_args_types; + desc.aggregate_function = func_ptr; + return desc; +} +void registerWindowGroupLimitRelParser(RelParserFactory & factory) +{ + auto builder = [](SerializedPlanParser * plan_parser) { return std::make_shared(plan_parser); }; + factory.registerBuilder(substrait::Rel::RelTypeCase::kWindowGroupLimit, builder); +} +} diff --git a/cpp-ch/local-engine/Parser/WindowGroupLimitRelParser.h b/cpp-ch/local-engine/Parser/WindowGroupLimitRelParser.h new file mode 100644 index 000000000000..38ffee0a40b2 --- /dev/null +++ b/cpp-ch/local-engine/Parser/WindowGroupLimitRelParser.h @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace local_engine +{ +/// Similar to WindowRelParser. Some differences +/// 1. cannot support aggregate functions. only support window functions: row_number, rank, dense_rank +/// 2. row_number, rank and dense_rank are mapped to new variants +/// 3. the output columns don't contain window function results +class WindowGroupLimitRelParser : public RelParser +{ +public: + explicit WindowGroupLimitRelParser(SerializedPlanParser * plan_parser_); + ~WindowGroupLimitRelParser() override = default; + DB::QueryPlanPtr + parse(DB::QueryPlanPtr current_plan_, const substrait::Rel & rel, std::list & rel_stack_) override; + const substrait::Rel & getSingleInput(const substrait::Rel & rel) override { return rel.windowgrouplimit().input(); } + +private: + DB::QueryPlanPtr current_plan; + String window_function_name; + + DB::WindowDescription buildWindowDescription(const substrait::WindowGroupLimitRel & win_rel_def); + /// There is only one type of window frame at present. + static DB::WindowFrame buildWindowFrame(const String & function_name); + + DB::SortDescription parsePartitionBy(const google::protobuf::RepeatedPtrField & expressions); + + static DB::WindowFunctionDescription buildWindowFunctionDescription(const String & function_name); +}; +} From 1950ac44bec3fd554dbf7e34088d857981217d36 Mon Sep 17 00:00:00 2001 From: lgbo-ustc Date: Tue, 3 Sep 2024 17:21:49 +0800 Subject: [PATCH 2/3] 0903 --- .../WindowGroupLimitFunctions.cpp | 92 +++++++++++++++++++ .../WindowGroupLimitFunctions.h | 33 +++++++ cpp-ch/local-engine/Common/CHUtil.cpp | 41 +++++---- .../Parser/WindowGroupLimitRelParser.cpp | 17 ++-- .../Parser/WindowGroupLimitRelParser.h | 2 +- 5 files changed, 156 insertions(+), 29 deletions(-) create mode 100644 cpp-ch/local-engine/AggregateFunctions/WindowGroupLimitFunctions.cpp create mode 100644 cpp-ch/local-engine/AggregateFunctions/WindowGroupLimitFunctions.h diff --git a/cpp-ch/local-engine/AggregateFunctions/WindowGroupLimitFunctions.cpp b/cpp-ch/local-engine/AggregateFunctions/WindowGroupLimitFunctions.cpp new file mode 100644 index 000000000000..57232b7ecf59 --- /dev/null +++ b/cpp-ch/local-engine/AggregateFunctions/WindowGroupLimitFunctions.cpp @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include +#include +#include +#include +#include + +namespace DB::ErrorCodes +{ +extern const int BAD_ARGUMENTS; +} + +namespace local_engine +{ +WindowFunctionTopRowNumber::WindowFunctionTopRowNumber(const String name, const DB::DataTypes & arg_types, const DB::Array & parameters_) + : DB::WindowFunction(name, arg_types, parameters_, std::make_shared()) +{ + if (parameters.size() != 1) + throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "{} needs a limit parameter", name); + limit = parameters[0].safeGet(); + LOG_ERROR(getLogger("WindowFunctionTopRowNumber"), "xxx {} limit: {}", name, limit); +} + +void WindowFunctionTopRowNumber::windowInsertResultInto(const DB::WindowTransform * transform, size_t function_index) const +{ + LOG_ERROR( + getLogger("WindowFunctionTopRowNumber"), + "xxx current row number: {}, current_row: {}@{}, partition_ended: {}", + transform->current_row_number, + transform->current_row.block, + transform->current_row.row, + transform->partition_ended); + /// If the rank value is larger then limit, and current block only contains rows which are all belong to one partition. + /// We cant drop this block directly. + if (!transform->partition_ended && !transform->current_row.row && transform->current_row_number > limit) + { + /// It's safe to make it mutable here. but it's still too dangerous, it may be changed in the future and make it unsafe. + auto * mutable_transform = const_cast(transform); + DB::WindowTransformBlock & current_block = mutable_transform->blockAt(mutable_transform->current_row); + current_block.rows = 0; + auto clear_columns = [](DB::Columns & cols) + { + DB::Columns new_cols; + for (const auto & col : cols) + { + new_cols.push_back(std::move(col->cloneEmpty())); + } + cols = new_cols; + }; + clear_columns(current_block.original_input_columns); + clear_columns(current_block.input_columns); + clear_columns(current_block.casted_columns); + mutable_transform->current_row.block += 1; + throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "{} is not implemented", name); + } + else + { + auto & to_col = *transform->blockAt(transform->current_row).output_columns[function_index]; + assert_cast(to_col).getData().push_back(transform->current_row_number); + } +} + +void registerWindowGroupLimitFunctions(DB::AggregateFunctionFactory & factory) +{ + const DB::AggregateFunctionProperties properties + = {.returns_default_when_only_null = true, .is_order_dependent = true, .is_window_function = true}; + factory.registerFunction( + "top_row_number", + {[](const String & name, const DB::DataTypes & args_type, const DB::Array & parameters, const DB::Settings *) + { return std::make_shared(name, args_type, parameters); }, + properties}, + DB::AggregateFunctionFactory::Case::Insensitive); +} +} diff --git a/cpp-ch/local-engine/AggregateFunctions/WindowGroupLimitFunctions.h b/cpp-ch/local-engine/AggregateFunctions/WindowGroupLimitFunctions.h new file mode 100644 index 000000000000..6c5cc19458d3 --- /dev/null +++ b/cpp-ch/local-engine/AggregateFunctions/WindowGroupLimitFunctions.h @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +namespace local_engine +{ +class WindowFunctionTopRowNumber : public DB::WindowFunction +{ +public: + explicit WindowFunctionTopRowNumber(const String name, const DB::DataTypes & arg_types_, const DB::Array & parameters_); + ~WindowFunctionTopRowNumber() override = default; + + void windowInsertResultInto(const DB::WindowTransform * transform, size_t function_index) const override; + bool allocatesMemoryInArena() const override { return false; } + +private: + size_t limit = 0; +}; +} diff --git a/cpp-ch/local-engine/Common/CHUtil.cpp b/cpp-ch/local-engine/Common/CHUtil.cpp index b702082d6ff3..7b1536c6e5f0 100644 --- a/cpp-ch/local-engine/Common/CHUtil.cpp +++ b/cpp-ch/local-engine/Common/CHUtil.cpp @@ -560,9 +560,7 @@ std::map BackendInitializerUtil::getBackendConfMap(std } std::vector BackendInitializerUtil::wrapDiskPathConfig( - const String & path_prefix, - const String & path_suffix, - Poco::Util::AbstractConfiguration & config) + const String & path_prefix, const String & path_suffix, Poco::Util::AbstractConfiguration & config) { std::vector changed_paths; if (path_prefix.empty() && path_suffix.empty()) @@ -657,9 +655,7 @@ DB::Context::ConfigurationPtr BackendInitializerUtil::initConfig(std::mapgetString("timezone"); const String mapped_timezone = DateTimeUtil::convertTimeZone(config_timezone); - if (0 != setenv("TZ", mapped_timezone.data(), 1)) // NOLINT(concurrency-mt-unsafe) // ok if not called concurrently with other setenv/getenv + if (0 + != setenv( + "TZ", mapped_timezone.data(), 1)) // NOLINT(concurrency-mt-unsafe) // ok if not called concurrently with other setenv/getenv throw Poco::Exception("Cannot setenv TZ variable"); tzset(); @@ -807,8 +805,7 @@ void BackendInitializerUtil::initSettings(std::map & b { auto mem_gb = task_memory / static_cast(1_GiB); // 2.8x+5, Heuristics calculate the block size of external sort, [8,16] - settings.prefer_external_sort_block_bytes = std::max(std::min( - static_cast(2.8*mem_gb + 5), 16ul), 8ul) * 1024 * 1024; + settings.prefer_external_sort_block_bytes = std::max(std::min(static_cast(2.8 * mem_gb + 5), 16ul), 8ul) * 1024 * 1024; } } } @@ -848,10 +845,14 @@ void BackendInitializerUtil::initContexts(DB::Context::ConfigurationPtr config) global_context->setMarkCache(mark_cache_policy, mark_cache_size, mark_cache_size_ratio); - String index_uncompressed_cache_policy = config->getString("index_uncompressed_cache_policy", DEFAULT_INDEX_UNCOMPRESSED_CACHE_POLICY); - size_t index_uncompressed_cache_size = config->getUInt64("index_uncompressed_cache_size", DEFAULT_INDEX_UNCOMPRESSED_CACHE_MAX_SIZE); - double index_uncompressed_cache_size_ratio = config->getDouble("index_uncompressed_cache_size_ratio", DEFAULT_INDEX_UNCOMPRESSED_CACHE_SIZE_RATIO); - global_context->setIndexUncompressedCache(index_uncompressed_cache_policy, index_uncompressed_cache_size, index_uncompressed_cache_size_ratio); + String index_uncompressed_cache_policy + = config->getString("index_uncompressed_cache_policy", DEFAULT_INDEX_UNCOMPRESSED_CACHE_POLICY); + size_t index_uncompressed_cache_size + = config->getUInt64("index_uncompressed_cache_size", DEFAULT_INDEX_UNCOMPRESSED_CACHE_MAX_SIZE); + double index_uncompressed_cache_size_ratio + = config->getDouble("index_uncompressed_cache_size_ratio", DEFAULT_INDEX_UNCOMPRESSED_CACHE_SIZE_RATIO); + global_context->setIndexUncompressedCache( + index_uncompressed_cache_policy, index_uncompressed_cache_size, index_uncompressed_cache_size_ratio); String index_mark_cache_policy = config->getString("index_mark_cache_policy", DEFAULT_INDEX_MARK_CACHE_POLICY); size_t index_mark_cache_size = config->getUInt64("index_mark_cache_size", DEFAULT_INDEX_MARK_CACHE_MAX_SIZE); @@ -890,6 +891,7 @@ extern void registerAggregateFunctionCombinatorPartialMerge(AggregateFunctionCom extern void registerAggregateFunctionsBloomFilter(AggregateFunctionFactory &); extern void registerAggregateFunctionSparkAvg(AggregateFunctionFactory &); extern void registerFunctions(FunctionFactory &); +extern void registerWindowGroupLimitFunctions(AggregateFunctionFactory &); void registerAllFunctions() { @@ -899,6 +901,7 @@ void registerAllFunctions() auto & agg_factory = AggregateFunctionFactory::instance(); registerAggregateFunctionsBloomFilter(agg_factory); registerAggregateFunctionSparkAvg(agg_factory); + registerWindowGroupLimitFunctions(agg_factory); { /// register aggregate function combinators from local_engine auto & factory = AggregateFunctionCombinatorFactory::instance(); @@ -1023,11 +1026,13 @@ void BackendFinalizerUtil::finalizeGlobally() StorageMergeTreeFactory::clear(); QueryContext::resetGlobal(); std::lock_guard lock(paths_mutex); - std::ranges::for_each(paths_need_to_clean, [](const auto & path) - { - if (fs::exists(path)) - fs::remove_all(path); - }); + std::ranges::for_each( + paths_need_to_clean, + [](const auto & path) + { + if (fs::exists(path)) + fs::remove_all(path); + }); paths_need_to_clean.clear(); } diff --git a/cpp-ch/local-engine/Parser/WindowGroupLimitRelParser.cpp b/cpp-ch/local-engine/Parser/WindowGroupLimitRelParser.cpp index b1dd98282659..153918850ff9 100644 --- a/cpp-ch/local-engine/Parser/WindowGroupLimitRelParser.cpp +++ b/cpp-ch/local-engine/Parser/WindowGroupLimitRelParser.cpp @@ -27,7 +27,7 @@ namespace DB::ErrorCodes extern const int BAD_ARGUMENTS; } -const static String FUNCTION_ROW_NUM = "row_number"; +const static String FUNCTION_ROW_NUM = "top_row_number"; const static String FUNCTION_RANK = "top_rank"; const static String FUNCTION_DENSE_RANK = "top_dense_rank"; @@ -68,22 +68,18 @@ WindowGroupLimitRelParser::parse(DB::QueryPlanPtr current_plan_, const substrait current_plan->addStep(std::move(post_project_step)); LOG_ERROR(getLogger("WindowGroupLimitRelParser"), "xxx output header: {}", current_plan->getCurrentDataStream().header.dumpStructure()); - bool x = true; - if (x) - { - throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Invalide rel"); - } return std::move(current_plan); } DB::WindowFrame WindowGroupLimitRelParser::buildWindowFrame(const String & function_name) { + // We only need first rows, so let the begin type is unbounded is OK DB::WindowFrame frame; if (function_name == FUNCTION_ROW_NUM) { frame.type = DB::WindowFrame::FrameType::ROWS; - frame.begin_type = DB::WindowFrame::BoundaryType::Offset; - frame.begin_offset = 1; + frame.begin_type = DB::WindowFrame::BoundaryType::Unbounded; + frame.begin_offset = 0; frame.begin_preceding = true; frame.end_type = DB::WindowFrame::BoundaryType::Current; frame.end_offset = 0; @@ -123,7 +119,7 @@ DB::WindowDescription WindowGroupLimitRelParser::buildWindowDescription(const su ss << win_desc.frame.toString(); win_desc.window_name = ss.str(); - win_desc.window_functions.emplace_back(buildWindowFunctionDescription(window_function_name)); + win_desc.window_functions.emplace_back(buildWindowFunctionDescription(window_function_name, static_cast(win_rel_def.limit()))); return win_desc; } @@ -153,7 +149,7 @@ WindowGroupLimitRelParser::parsePartitionBy(const google::protobuf::RepeatedPtrF return sort_desc; } -DB::WindowFunctionDescription WindowGroupLimitRelParser::buildWindowFunctionDescription(const String & function_name) +DB::WindowFunctionDescription WindowGroupLimitRelParser::buildWindowFunctionDescription(const String & function_name, size_t limit) { DB::WindowFunctionDescription desc; desc.column_name = function_name; @@ -162,6 +158,7 @@ DB::WindowFunctionDescription WindowGroupLimitRelParser::buildWindowFunctionDesc DB::Names func_args; DB::DataTypes func_args_types; DB::Array func_params; + func_params.push_back(limit); auto func_ptr = RelParser::getAggregateFunction(function_name, func_args_types, func_properties, func_params); desc.argument_names = func_args; desc.argument_types = func_args_types; diff --git a/cpp-ch/local-engine/Parser/WindowGroupLimitRelParser.h b/cpp-ch/local-engine/Parser/WindowGroupLimitRelParser.h index 38ffee0a40b2..6b7d3bbf33e5 100644 --- a/cpp-ch/local-engine/Parser/WindowGroupLimitRelParser.h +++ b/cpp-ch/local-engine/Parser/WindowGroupLimitRelParser.h @@ -52,6 +52,6 @@ class WindowGroupLimitRelParser : public RelParser DB::SortDescription parsePartitionBy(const google::protobuf::RepeatedPtrField & expressions); - static DB::WindowFunctionDescription buildWindowFunctionDescription(const String & function_name); + static DB::WindowFunctionDescription buildWindowFunctionDescription(const String & function_name, size_t limit); }; } From 064bc5defaf707a1d6e5e74070dca48ea0d1db33 Mon Sep 17 00:00:00 2001 From: lgbo-ustc Date: Mon, 9 Sep 2024 15:55:24 +0800 Subject: [PATCH 3/3] implement window group limit --- .../clickhouse/CHSparkPlanExecApi.scala | 16 + .../CHWindowGroupLimitExecTransformer.scala | 187 +++++++++ .../GlutenClickHouseTPCDSAbstractSuite.scala | 2 +- ...enClickHouseTPCHSaltNullParquetSuite.scala | 6 +- .../WindowGroupLimitFunctions.cpp | 92 ----- .../WindowGroupLimitFunctions.h | 33 -- cpp-ch/local-engine/Common/CHUtil.cpp | 4 +- .../Operator/ReplicateRowsStep.cpp | 20 +- .../Operator/WindowGroupLimitStep.cpp | 365 ++++++++++++++++++ .../Operator/WindowGroupLimitStep.h | 51 +++ .../Parser/AdvancedParametersParseUtil.cpp | 31 +- .../Parser/AdvancedParametersParseUtil.h | 9 +- .../Parser/WindowGroupLimitRelParser.cpp | 149 +++---- .../Parser/WindowGroupLimitRelParser.h | 9 +- .../gluten/backendsapi/SparkPlanExecApi.scala | 10 + .../columnar/OffloadSingleNode.scala | 2 +- 16 files changed, 714 insertions(+), 272 deletions(-) create mode 100644 backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHWindowGroupLimitExecTransformer.scala delete mode 100644 cpp-ch/local-engine/AggregateFunctions/WindowGroupLimitFunctions.cpp delete mode 100644 cpp-ch/local-engine/AggregateFunctions/WindowGroupLimitFunctions.h create mode 100644 cpp-ch/local-engine/Operator/WindowGroupLimitStep.cpp create mode 100644 cpp-ch/local-engine/Operator/WindowGroupLimitStep.h diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala index 1108b8b3c501..f765a75d2f7d 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala @@ -53,6 +53,7 @@ import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleEx import org.apache.spark.sql.execution.joins.{BuildSideRelation, ClickHouseBuildSideRelation, HashedRelationBroadcastMode} import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.execution.utils.{CHExecUtil, PushDownUtil} +import org.apache.spark.sql.execution.window._ import org.apache.spark.sql.types.{DecimalType, StructType} import org.apache.spark.sql.vectorized.ColumnarBatch @@ -909,4 +910,19 @@ class CHSparkPlanExecApi extends SparkPlanExecApi with Logging { toScale: Int): DecimalType = { SparkShimLoader.getSparkShims.genDecimalRoundExpressionOutput(decimalType, toScale) } + + override def genWindowGroupLimitTransformer( + partitionSpec: Seq[Expression], + orderSpec: Seq[SortOrder], + rankLikeFunction: Expression, + limit: Int, + mode: WindowGroupLimitMode, + child: SparkPlan): SparkPlan = + CHWindowGroupLimitExecTransformer( + partitionSpec, + orderSpec, + rankLikeFunction, + limit, + mode, + child) } diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHWindowGroupLimitExecTransformer.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHWindowGroupLimitExecTransformer.scala new file mode 100644 index 000000000000..c2648f29ec4c --- /dev/null +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHWindowGroupLimitExecTransformer.scala @@ -0,0 +1,187 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.gluten.execution + +import org.apache.gluten.backendsapi.BackendsApiManager +import org.apache.gluten.exception.GlutenNotSupportException +import org.apache.gluten.expression._ +import org.apache.gluten.expression.{ConverterUtils, ExpressionConverter} +import org.apache.gluten.extension.ValidationResult +import org.apache.gluten.metrics.MetricsUpdater +import org.apache.gluten.substrait.`type`.TypeBuilder +import org.apache.gluten.substrait.SubstraitContext +import org.apache.gluten.substrait.extensions.ExtensionBuilder +import org.apache.gluten.substrait.rel.{RelBuilder, RelNode} + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning} +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.window.{Final, Partial, WindowGroupLimitMode} + +import com.google.protobuf.StringValue +import io.substrait.proto.SortField + +import scala.collection.JavaConverters._ + +case class CHWindowGroupLimitExecTransformer( + partitionSpec: Seq[Expression], + orderSpec: Seq[SortOrder], + rankLikeFunction: Expression, + limit: Int, + mode: WindowGroupLimitMode, + child: SparkPlan) + extends UnaryTransformSupport { + + @transient override lazy val metrics = + BackendsApiManager.getMetricsApiInstance.genWindowTransformerMetrics(sparkContext) + + override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = + copy(child = newChild) + + override def metricsUpdater(): MetricsUpdater = + BackendsApiManager.getMetricsApiInstance.genWindowTransformerMetricsUpdater(metrics) + + override def output: Seq[Attribute] = child.output + + override def requiredChildDistribution: Seq[Distribution] = mode match { + case Partial => super.requiredChildDistribution + case Final => + if (partitionSpec.isEmpty) { + AllTuples :: Nil + } else { + ClusteredDistribution(partitionSpec) :: Nil + } + } + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = { + if (BackendsApiManager.getSettings.requiredChildOrderingForWindowGroupLimit()) { + Seq(partitionSpec.map(SortOrder(_, Ascending)) ++ orderSpec) + } else { + Seq(Nil) + } + } + + override def outputOrdering: Seq[SortOrder] = { + if (requiredChildOrdering.forall(_.isEmpty)) { + // The Velox backend `TopNRowNumber` does not require child ordering, because it + // uses hash table to store partition and use priority queue to track of top limit rows. + // Ideally, the output of `TopNRowNumber` is unordered but it is grouped for partition keys. + // To be safe, here we do not propagate the ordering. + // TODO: Make the framework aware of grouped data distribution + Nil + } else { + child.outputOrdering + } + } + + override def outputPartitioning: Partitioning = child.outputPartitioning + + def getWindowGroupLimitRel( + context: SubstraitContext, + originalInputAttributes: Seq[Attribute], + operatorId: Long, + input: RelNode, + validation: Boolean): RelNode = { + val args = context.registeredFunction + // Partition By Expressions + val partitionsExpressions = partitionSpec + .map( + ExpressionConverter + .replaceWithExpressionTransformer(_, attributeSeq = child.output) + .doTransform(args)) + .asJava + + // Sort By Expressions + val sortFieldList = + orderSpec.map { + order => + val builder = SortField.newBuilder() + val exprNode = ExpressionConverter + .replaceWithExpressionTransformer(order.child, attributeSeq = child.output) + .doTransform(args) + builder.setExpr(exprNode.toProtobuf) + builder.setDirectionValue(SortExecTransformer.transformSortDirection(order)) + builder.build() + }.asJava + if (!validation) { + val windowFunction = rankLikeFunction match { + case _: RowNumber => ExpressionNames.ROW_NUMBER + case _: Rank => ExpressionNames.RANK + case _: DenseRank => ExpressionNames.DENSE_RANK + case _ => throw new GlutenNotSupportException(s"Unknow window function $rankLikeFunction") + } + val parametersStr = new StringBuffer("WindowGroupLimitParameters:") + parametersStr + .append("window_function=") + .append(windowFunction) + .append("\n") + val message = StringValue.newBuilder().setValue(parametersStr.toString).build() + val extensionNode = ExtensionBuilder.makeAdvancedExtension( + BackendsApiManager.getTransformerApiInstance.packPBMessage(message), + null) + RelBuilder.makeWindowGroupLimitRel( + input, + partitionsExpressions, + sortFieldList, + limit, + extensionNode, + context, + operatorId) + } else { + // Use a extension node to send the input types through Substrait plan for validation. + val inputTypeNodeList = originalInputAttributes + .map(attr => ConverterUtils.getTypeNode(attr.dataType, attr.nullable)) + .asJava + val extensionNode = ExtensionBuilder.makeAdvancedExtension( + BackendsApiManager.getTransformerApiInstance.packPBMessage( + TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf)) + + RelBuilder.makeWindowGroupLimitRel( + input, + partitionsExpressions, + sortFieldList, + limit, + extensionNode, + context, + operatorId) + } + } + + override protected def doValidateInternal(): ValidationResult = { + if (!BackendsApiManager.getSettings.supportWindowGroupLimitExec(rankLikeFunction)) { + return ValidationResult + .failed(s"Found unsupported rank like function: $rankLikeFunction") + } + val substraitContext = new SubstraitContext + val operatorId = substraitContext.nextOperatorId(this.nodeName) + + val relNode = + getWindowGroupLimitRel(substraitContext, child.output, operatorId, null, validation = true) + + doNativeValidation(substraitContext, relNode) + } + + override protected def doTransform(context: SubstraitContext): TransformContext = { + val childCtx = child.asInstanceOf[TransformSupport].transform(context) + val operatorId = context.nextOperatorId(this.nodeName) + + val currRel = + getWindowGroupLimitRel(context, child.output, operatorId, childCtx.root, validation = false) + assert(currRel != null, "Window Group Limit Rel should be valid") + TransformContext(childCtx.outputAttributes, output, currRel) + } +} diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCDSAbstractSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCDSAbstractSuite.scala index 03b26fa985ea..abb7d27ffe92 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCDSAbstractSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCDSAbstractSuite.scala @@ -62,7 +62,7 @@ abstract class GlutenClickHouseTPCDSAbstractSuite }) protected def fallbackSets(isAqe: Boolean): Set[Int] = { - if (isSparkVersionGE("3.5")) Set(44, 67, 70) else Set.empty[Int] + Set.empty[Int] } protected def excludedTpcdsQueries: Set[String] = Set( "q66" // inconsistent results diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala index f7cf0de3762d..9ac35441acb4 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala @@ -1855,7 +1855,7 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends GlutenClickHouseTPCHAbstr | ) t1 |) t2 where rank = 1 """.stripMargin - compareResultsAgainstVanillaSpark(sql, true, { _ => }, isSparkVersionLE("3.3")) + compareResultsAgainstVanillaSpark(sql, true, { _ => }) } test("GLUTEN-1874 not null in both streams") { @@ -1873,7 +1873,7 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends GlutenClickHouseTPCHAbstr | ) t1 |) t2 where rank = 1 """.stripMargin - compareResultsAgainstVanillaSpark(sql, true, { _ => }, isSparkVersionLE("3.3")) + compareResultsAgainstVanillaSpark(sql, true, { _ => }) } test("GLUTEN-2095: test cast(string as binary)") { @@ -2456,7 +2456,7 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends GlutenClickHouseTPCHAbstr | ) t1 |) t2 where rank = 1 order by p_partkey limit 100 |""".stripMargin - runQueryAndCompare(sql, noFallBack = isSparkVersionLE("3.3"))({ _ => }) + runQueryAndCompare(sql, noFallBack = true)({ _ => }) } test("GLUTEN-4190: crush on flattening a const null column") { diff --git a/cpp-ch/local-engine/AggregateFunctions/WindowGroupLimitFunctions.cpp b/cpp-ch/local-engine/AggregateFunctions/WindowGroupLimitFunctions.cpp deleted file mode 100644 index 57232b7ecf59..000000000000 --- a/cpp-ch/local-engine/AggregateFunctions/WindowGroupLimitFunctions.cpp +++ /dev/null @@ -1,92 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include -#include -#include -#include -#include -#include -#include -#include - -namespace DB::ErrorCodes -{ -extern const int BAD_ARGUMENTS; -} - -namespace local_engine -{ -WindowFunctionTopRowNumber::WindowFunctionTopRowNumber(const String name, const DB::DataTypes & arg_types, const DB::Array & parameters_) - : DB::WindowFunction(name, arg_types, parameters_, std::make_shared()) -{ - if (parameters.size() != 1) - throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "{} needs a limit parameter", name); - limit = parameters[0].safeGet(); - LOG_ERROR(getLogger("WindowFunctionTopRowNumber"), "xxx {} limit: {}", name, limit); -} - -void WindowFunctionTopRowNumber::windowInsertResultInto(const DB::WindowTransform * transform, size_t function_index) const -{ - LOG_ERROR( - getLogger("WindowFunctionTopRowNumber"), - "xxx current row number: {}, current_row: {}@{}, partition_ended: {}", - transform->current_row_number, - transform->current_row.block, - transform->current_row.row, - transform->partition_ended); - /// If the rank value is larger then limit, and current block only contains rows which are all belong to one partition. - /// We cant drop this block directly. - if (!transform->partition_ended && !transform->current_row.row && transform->current_row_number > limit) - { - /// It's safe to make it mutable here. but it's still too dangerous, it may be changed in the future and make it unsafe. - auto * mutable_transform = const_cast(transform); - DB::WindowTransformBlock & current_block = mutable_transform->blockAt(mutable_transform->current_row); - current_block.rows = 0; - auto clear_columns = [](DB::Columns & cols) - { - DB::Columns new_cols; - for (const auto & col : cols) - { - new_cols.push_back(std::move(col->cloneEmpty())); - } - cols = new_cols; - }; - clear_columns(current_block.original_input_columns); - clear_columns(current_block.input_columns); - clear_columns(current_block.casted_columns); - mutable_transform->current_row.block += 1; - throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "{} is not implemented", name); - } - else - { - auto & to_col = *transform->blockAt(transform->current_row).output_columns[function_index]; - assert_cast(to_col).getData().push_back(transform->current_row_number); - } -} - -void registerWindowGroupLimitFunctions(DB::AggregateFunctionFactory & factory) -{ - const DB::AggregateFunctionProperties properties - = {.returns_default_when_only_null = true, .is_order_dependent = true, .is_window_function = true}; - factory.registerFunction( - "top_row_number", - {[](const String & name, const DB::DataTypes & args_type, const DB::Array & parameters, const DB::Settings *) - { return std::make_shared(name, args_type, parameters); }, - properties}, - DB::AggregateFunctionFactory::Case::Insensitive); -} -} diff --git a/cpp-ch/local-engine/AggregateFunctions/WindowGroupLimitFunctions.h b/cpp-ch/local-engine/AggregateFunctions/WindowGroupLimitFunctions.h deleted file mode 100644 index 6c5cc19458d3..000000000000 --- a/cpp-ch/local-engine/AggregateFunctions/WindowGroupLimitFunctions.h +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include - -namespace local_engine -{ -class WindowFunctionTopRowNumber : public DB::WindowFunction -{ -public: - explicit WindowFunctionTopRowNumber(const String name, const DB::DataTypes & arg_types_, const DB::Array & parameters_); - ~WindowFunctionTopRowNumber() override = default; - - void windowInsertResultInto(const DB::WindowTransform * transform, size_t function_index) const override; - bool allocatesMemoryInArena() const override { return false; } - -private: - size_t limit = 0; -}; -} diff --git a/cpp-ch/local-engine/Common/CHUtil.cpp b/cpp-ch/local-engine/Common/CHUtil.cpp index 7b1536c6e5f0..94a214e5e571 100644 --- a/cpp-ch/local-engine/Common/CHUtil.cpp +++ b/cpp-ch/local-engine/Common/CHUtil.cpp @@ -53,6 +53,7 @@ #include #include #include +#include #include #include #include @@ -315,7 +316,6 @@ DB::Block BlockUtil::concatenateBlocksMemoryEfficiently(std::vector & return out; } - size_t PODArrayUtil::adjustMemoryEfficientSize(size_t n) { /// According to definition of DEFUALT_BLOCK_SIZE @@ -891,7 +891,6 @@ extern void registerAggregateFunctionCombinatorPartialMerge(AggregateFunctionCom extern void registerAggregateFunctionsBloomFilter(AggregateFunctionFactory &); extern void registerAggregateFunctionSparkAvg(AggregateFunctionFactory &); extern void registerFunctions(FunctionFactory &); -extern void registerWindowGroupLimitFunctions(AggregateFunctionFactory &); void registerAllFunctions() { @@ -901,7 +900,6 @@ void registerAllFunctions() auto & agg_factory = AggregateFunctionFactory::instance(); registerAggregateFunctionsBloomFilter(agg_factory); registerAggregateFunctionSparkAvg(agg_factory); - registerWindowGroupLimitFunctions(agg_factory); { /// register aggregate function combinators from local_engine auto & factory = AggregateFunctionCombinatorFactory::instance(); diff --git a/cpp-ch/local-engine/Operator/ReplicateRowsStep.cpp b/cpp-ch/local-engine/Operator/ReplicateRowsStep.cpp index f2d4bc8a865d..ecb027c18f0a 100644 --- a/cpp-ch/local-engine/Operator/ReplicateRowsStep.cpp +++ b/cpp-ch/local-engine/Operator/ReplicateRowsStep.cpp @@ -32,16 +32,14 @@ namespace local_engine { static DB::ITransformingStep::Traits getTraits() { - return DB::ITransformingStep::Traits - { + return DB::ITransformingStep::Traits{ { .preserves_number_of_streams = true, .preserves_sorting = false, }, { .preserves_number_of_rows = false, - } - }; + }}; } ReplicateRowsStep::ReplicateRowsStep(const DB::DataStream & input_stream) @@ -49,7 +47,7 @@ ReplicateRowsStep::ReplicateRowsStep(const DB::DataStream & input_stream) { } -DB::Block ReplicateRowsStep::transformHeader(const DB::Block& input) +DB::Block ReplicateRowsStep::transformHeader(const DB::Block & input) { DB::Block output; for (int i = 1; i < input.columns(); i++) @@ -59,15 +57,9 @@ DB::Block ReplicateRowsStep::transformHeader(const DB::Block& input) return output; } -void ReplicateRowsStep::transformPipeline( - DB::QueryPipelineBuilder & pipeline, - const DB::BuildQueryPipelineSettings & /*settings*/) +void ReplicateRowsStep::transformPipeline(DB::QueryPipelineBuilder & pipeline, const DB::BuildQueryPipelineSettings & /*settings*/) { - pipeline.addSimpleTransform( - [&](const DB::Block & header) - { - return std::make_shared(header); - }); + pipeline.addSimpleTransform([&](const DB::Block & header) { return std::make_shared(header); }); } void ReplicateRowsStep::updateOutputStream() @@ -105,4 +97,4 @@ void ReplicateRowsTransform::transform(DB::Chunk & chunk) chunk.setColumns(std::move(mutable_columns), total_rows); } -} \ No newline at end of file +} diff --git a/cpp-ch/local-engine/Operator/WindowGroupLimitStep.cpp b/cpp-ch/local-engine/Operator/WindowGroupLimitStep.cpp new file mode 100644 index 000000000000..af04ef579028 --- /dev/null +++ b/cpp-ch/local-engine/Operator/WindowGroupLimitStep.cpp @@ -0,0 +1,365 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "WindowGroupLimitStep.h" +#include +#include +#include +#include +#include +#include +#include + +namespace DB::ErrorCodes +{ +extern const int LOGICAL_ERROR; +} + +namespace local_engine +{ + +enum class WindowGroupLimitFunction +{ + RowNumber, + Rank, + DenseRank +}; + + +template +class WindowGroupLimitTransform : public DB::IProcessor +{ +public: + using Status = DB::IProcessor::Status; + explicit WindowGroupLimitTransform( + const DB::Block & header_, const std::vector & partition_columns_, const std::vector & sort_columns_, size_t limit_) + : DB::IProcessor({header_}, {header_}) + , header(header_) + , partition_columns(partition_columns_) + , sort_columns(sort_columns_) + , limit(limit_) + + { + } + ~WindowGroupLimitTransform() override = default; + String getName() const override { return "WindowGroupLimitTransform"; } + + Status prepare() override + { + auto & output_port = outputs.front(); + auto & input_port = inputs.front(); + if (output_port.isFinished()) + { + input_port.close(); + return Status::Finished; + } + + if (has_output) + { + if (output_port.canPush()) + { + output_port.push(std::move(output_chunk)); + has_output = false; + } + return Status::PortFull; + } + + if (has_input) + return Status::Ready; + + if (input_port.isFinished()) + { + output_port.finish(); + return Status::Finished; + } + input_port.setNeeded(); + if (!input_port.hasData()) + return Status::NeedData; + input_chunk = input_port.pull(true); + has_input = true; + return Status::Ready; + } + + void work() override + { + if (!has_input) [[unlikely]] + { + return; + } + DB::Block block = header.cloneWithColumns(input_chunk.getColumns()); + size_t partition_start_row = 0; + size_t chunk_rows = input_chunk.getNumRows(); + while (partition_start_row < chunk_rows) + { + auto next_partition_start_row = advanceNextPartition(input_chunk, partition_start_row); + iteratePartition(input_chunk, partition_start_row, next_partition_start_row); + partition_start_row = next_partition_start_row; + // corner case, the partition end row is the last row of chunk. + if (partition_start_row < chunk_rows) + { + current_row_rank_value = 1; + if constexpr (function == WindowGroupLimitFunction::Rank) + current_peer_group_rows = 0; + partition_start_row_columns = extractOneRowColumns(input_chunk, partition_start_row); + } + } + + if (!output_columns.empty() && output_columns[0]->size() > 0) + { + auto rows = output_columns[0]->size(); + output_chunk = DB::Chunk(std::move(output_columns), rows); + output_columns.clear(); + has_output = true; + } + has_input = false; + } + +private: + DB::Block header; + // Which columns are used as the partition keys + std::vector partition_columns; + // which columns are used as the order by keys, excluding partition columns. + std::vector sort_columns; + // Limitations for each partition. + size_t limit = 0; + + bool has_input = false; + DB::Chunk input_chunk; + bool has_output = false; + DB::MutableColumns output_columns; + DB::Chunk output_chunk; + + // We don't have window frame here. in fact all of frame are (unbounded preceding, current row] + // the start value is 1 + size_t current_row_rank_value = 1; + // rank need this to record how many rows in current peer group. + // A peer group in a partition is defined as the rows have the same value on the sort columns. + size_t current_peer_group_rows = 0; + + DB::Columns partition_start_row_columns; + DB::Columns peer_group_start_row_columns; + + + size_t advanceNextPartition(const DB::Chunk & chunk, size_t start_offset) + { + if (partition_start_row_columns.empty()) + partition_start_row_columns = extractOneRowColumns(chunk, start_offset); + + size_t max_row = chunk.getNumRows(); + for (size_t i = start_offset; i < max_row; ++i) + { + if (!isRowEqual(partition_columns, partition_start_row_columns, 0, chunk.getColumns(), i)) + { + return i; + } + } + return max_row; + } + + static DB::Columns extractOneRowColumns(const DB::Chunk & chunk, size_t offset) + { + DB::Columns row; + for (const auto & col : chunk.getColumns()) + { + auto new_col = col->cloneEmpty(); + new_col->insertFrom(*col, offset); + row.push_back(std::move(new_col)); + } + return row; + } + + static bool isRowEqual( + const std::vector & fields, const DB::Columns & left_cols, size_t loffset, const DB::Columns & right_cols, size_t roffset) + { + for (size_t i = 0; i < fields.size(); ++i) + { + const auto & field = fields[i]; + /// don't care about nan_direction_hint + if (left_cols[field]->compareAt(loffset, roffset, *right_cols[field], 1)) + return false; + } + return true; + } + + void iteratePartition(const DB::Chunk & chunk, size_t start_offset, size_t end_offset) + { + // Skip the rest rows int this partition. + if (current_row_rank_value > limit) + return; + + + size_t chunk_rows = chunk.getNumRows(); + auto has_peer_group_ended = [&](size_t offset, size_t partition_end_offset, size_t chunk_rows_) + { return offset < partition_end_offset || end_offset < chunk_rows_; }; + auto try_end_peer_group + = [&](size_t peer_group_start_offset, size_t next_peer_group_start_offset, size_t partition_end_offset, size_t chunk_rows_) + { + if constexpr (function == WindowGroupLimitFunction::Rank) + { + current_peer_group_rows += next_peer_group_start_offset - peer_group_start_offset; + if (has_peer_group_ended(next_peer_group_start_offset, partition_end_offset, chunk_rows_)) + { + current_row_rank_value += current_peer_group_rows; + current_peer_group_rows = 0; + peer_group_start_row_columns = extractOneRowColumns(chunk, next_peer_group_start_offset); + } + } + else if constexpr (function == WindowGroupLimitFunction::DenseRank) + { + if (has_peer_group_ended(next_peer_group_start_offset, partition_end_offset, chunk_rows_)) + { + current_row_rank_value += 1; + peer_group_start_row_columns = extractOneRowColumns(chunk, next_peer_group_start_offset); + } + } + }; + + // This is a corner case. prev partition's last row is the last row of a chunk. + if (start_offset >= end_offset) + { + assert(!start_offset); + try_end_peer_group(start_offset, end_offset, end_offset, chunk_rows); + return; + } + + // row_number is simple + if constexpr (function == WindowGroupLimitFunction::RowNumber) + { + size_t rows = end_offset - start_offset; + size_t limit_remained = limit - current_row_rank_value + 1; + rows = rows > limit_remained ? limit_remained : rows; + insertResultValue(chunk, start_offset, rows); + current_row_rank_value += rows; + } + else + { + size_t peer_group_start_offset = start_offset; + while (peer_group_start_offset < end_offset && current_row_rank_value <= limit) + { + auto next_peer_group_start_offset = advanceNextPeerGroup(chunk, peer_group_start_offset, end_offset); + + insertResultValue(chunk, peer_group_start_offset, next_peer_group_start_offset - peer_group_start_offset); + try_end_peer_group(peer_group_start_offset, next_peer_group_start_offset, end_offset, chunk_rows); + peer_group_start_offset = next_peer_group_start_offset; + } + } + } + void insertResultValue(const DB::Chunk & chunk, size_t start_offset, size_t rows) + { + if (!rows) + return; + if (output_columns.empty()) + { + for (const auto & col : chunk.getColumns()) + { + output_columns.push_back(col->cloneEmpty()); + } + } + size_t i = 0; + for (const auto & col : chunk.getColumns()) + { + output_columns[i]->insertRangeFrom(*col, start_offset, rows); + i += 1; + } + } + size_t advanceNextPeerGroup(const DB::Chunk & chunk, size_t start_offset, size_t partition_end_offset) + { + if (peer_group_start_row_columns.empty()) + peer_group_start_row_columns = extractOneRowColumns(chunk, start_offset); + for (size_t i = start_offset; i < partition_end_offset; ++i) + { + if (!isRowEqual(sort_columns, peer_group_start_row_columns, 0, chunk.getColumns(), i)) + { + return i; + } + } + return partition_end_offset; + } +}; + +static DB::ITransformingStep::Traits getTraits() +{ + return DB::ITransformingStep::Traits{ + { + .preserves_number_of_streams = false, + .preserves_sorting = true, + }, + { + .preserves_number_of_rows = false, + }}; +} + +WindowGroupLimitStep::WindowGroupLimitStep( + const DB::DataStream & input_stream_, + const String & function_name_, + const std::vector partition_columns_, + const std::vector sort_columns_, + size_t limit_) + : DB::ITransformingStep(input_stream_, input_stream_.header, getTraits()) + , function_name(function_name_) + , partition_columns(partition_columns_) + , sort_columns(sort_columns_) + , limit(limit_) +{ +} + +void WindowGroupLimitStep::describePipeline(DB::IQueryPlanStep::FormatSettings & settings) const +{ + if (!processors.empty()) + DB::IQueryPlanStep::describePipeline(processors, settings); +} + +void WindowGroupLimitStep::updateOutputStream() +{ + output_stream = createOutputStream(input_streams.front(), input_streams.front().header, getDataStreamTraits()); +} + + +void WindowGroupLimitStep::transformPipeline(DB::QueryPipelineBuilder & pipeline, const DB::BuildQueryPipelineSettings & /*settings*/) +{ + if (function_name == "row_number") + { + pipeline.addSimpleTransform( + [&](const DB::Block & header) + { + return std::make_shared>( + header, partition_columns, sort_columns, limit); + }); + } + else if (function_name == "rank") + { + pipeline.addSimpleTransform( + [&](const DB::Block & header) { + return std::make_shared>( + header, partition_columns, sort_columns, limit); + }); + } + else if (function_name == "dense_rank") + { + pipeline.addSimpleTransform( + [&](const DB::Block & header) + { + return std::make_shared>( + header, partition_columns, sort_columns, limit); + }); + } + else + { + throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Unsupport function {} in WindowGroupLimit", function_name); + } +} +} diff --git a/cpp-ch/local-engine/Operator/WindowGroupLimitStep.h b/cpp-ch/local-engine/Operator/WindowGroupLimitStep.h new file mode 100644 index 000000000000..bbbbf42abc55 --- /dev/null +++ b/cpp-ch/local-engine/Operator/WindowGroupLimitStep.h @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include +#include +#include +#include + +namespace local_engine +{ +class WindowGroupLimitStep : public DB::ITransformingStep +{ +public: + explicit WindowGroupLimitStep( + const DB::DataStream & input_stream_, + const String & function_name_, + const std::vector partition_columns_, + const std::vector sort_columns_, + size_t limit_); + ~WindowGroupLimitStep() override = default; + + String getName() const override { return "WindowGroupLimitStep"; } + + void transformPipeline(DB::QueryPipelineBuilder & pipeline, const DB::BuildQueryPipelineSettings & settings) override; + void describePipeline(DB::IQueryPlanStep::FormatSettings & settings) const override; + void updateOutputStream() override; + +private: + // window function name, one of row_number, rank and dense_rank + String function_name; + std::vector partition_columns; + std::vector sort_columns; + size_t limit; +}; + +} diff --git a/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.cpp b/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.cpp index 42d4f4d4d8cd..cc7738a15aaa 100644 --- a/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.cpp +++ b/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.cpp @@ -14,25 +14,26 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include +#include "AdvancedParametersParseUtil.h" #include #include -#include -#include +#include #include +#include #include + namespace DB::ErrorCodes { - extern const int BAD_ARGUMENTS; +extern const int BAD_ARGUMENTS; } namespace local_engine { -template +template void tryAssign(const std::unordered_map & kvs, const String & key, T & v); -template<> +template <> void tryAssign(const std::unordered_map & kvs, const String & key, String & v) { auto it = kvs.find(key); @@ -40,7 +41,7 @@ void tryAssign(const std::unordered_map & kvs, const Str v = it->second; } -template<> +template <> void tryAssign(const std::unordered_map & kvs, const String & key, bool & v) { auto it = kvs.find(key); @@ -57,7 +58,7 @@ void tryAssign(const std::unordered_map & kvs, const Strin } } -template<> +template <> void tryAssign(const std::unordered_map & kvs, const String & key, Int64 & v) { auto it = kvs.find(key); @@ -94,9 +95,9 @@ void readStringUntilCharsInto(String & s, DB::ReadBuffer & buf) std::unordered_map> convertToKVs(const String & advance) { std::unordered_map> res; - std::unordered_map *kvs; + std::unordered_map * kvs; DB::ReadBufferFromString in(advance); - while(!in.eof()) + while (!in.eof()) { String key; readStringUntilCharsInto<'=', '\n', ':'>(key, in); @@ -146,5 +147,13 @@ JoinOptimizationInfo JoinOptimizationInfo::parse(const String & advance) tryAssign(kvs, "numPartitions", info.partitions_num); return info; } -} +WindowGroupOptimizationInfo WindowGroupOptimizationInfo::parse(const String & advance) +{ + WindowGroupOptimizationInfo info; + auto kkvs = convertToKVs(advance); + auto & kvs = kkvs["WindowGroupLimitParameters"]; + tryAssign(kvs, "window_function", info.window_function); + return info; +} +} diff --git a/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.h b/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.h index 5f6fe6d256e3..fc478db33bfd 100644 --- a/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.h +++ b/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.h @@ -16,10 +16,10 @@ */ #pragma once #include +#include namespace local_engine { - std::unordered_map> convertToKVs(const String & advance); @@ -38,5 +38,10 @@ struct JoinOptimizationInfo static JoinOptimizationInfo parse(const String & advance); }; -} +struct WindowGroupOptimizationInfo +{ + String window_function; + static WindowGroupOptimizationInfo parse(const String & advnace); +}; +} diff --git a/cpp-ch/local-engine/Parser/WindowGroupLimitRelParser.cpp b/cpp-ch/local-engine/Parser/WindowGroupLimitRelParser.cpp index 153918850ff9..f6c10386f405 100644 --- a/cpp-ch/local-engine/Parser/WindowGroupLimitRelParser.cpp +++ b/cpp-ch/local-engine/Parser/WindowGroupLimitRelParser.cpp @@ -15,127 +15,61 @@ * limitations under the License. */ +#include "WindowGroupLimitRelParser.h" #include +#include +#include #include #include #include -#include #include +#include +#include "AdvancedParametersParseUtil.h" namespace DB::ErrorCodes { extern const int BAD_ARGUMENTS; } -const static String FUNCTION_ROW_NUM = "top_row_number"; -const static String FUNCTION_RANK = "top_rank"; -const static String FUNCTION_DENSE_RANK = "top_dense_rank"; - namespace local_engine { WindowGroupLimitRelParser::WindowGroupLimitRelParser(SerializedPlanParser * plan_parser_) : RelParser(plan_parser_) { - LOG_ERROR(getLogger("WindowGroupLimitRelParser"), "xxx new parrser"); } + DB::QueryPlanPtr WindowGroupLimitRelParser::parse(DB::QueryPlanPtr current_plan_, const substrait::Rel & rel, std::list & rel_stack_) { const auto win_rel_def = rel.windowgrouplimit(); - current_plan = std::move(current_plan_); - - DB::Block output_header = current_plan->getCurrentDataStream().header; - - window_function_name = FUNCTION_ROW_NUM; - LOG_ERROR(getLogger("WindowGroupLimitRelParser"), "xxx input header: {}", current_plan->getCurrentDataStream().header.dumpStructure()); + google::protobuf::StringValue optimize_info_str; + optimize_info_str.ParseFromString(win_rel_def.advanced_extension().optimization().value()); + auto optimization_info = WindowGroupOptimizationInfo::parse(optimize_info_str.value()); + window_function_name = optimization_info.window_function; - /// Only one window function in one window group limit - auto win_desc = buildWindowDescription(win_rel_def); + current_plan = std::move(current_plan_); - auto win_step = std::make_unique(current_plan->getCurrentDataStream(), win_desc, win_desc.window_functions, false); - win_step->setStepDescription("Window Group Limit " + win_desc.window_name); - steps.emplace_back(win_step.get()); - current_plan->addStep(std::move(win_step)); + auto partition_fields = parsePartitoinFields(win_rel_def.partition_expressions()); + auto sort_fields = parseSortFields(win_rel_def.sorts()); + size_t limit = static_cast(win_rel_def.limit()); - /// remove the window function result column which is not needed in later steps - DB::ActionsDAG post_project_actions_dag = DB::ActionsDAG::makeConvertingActions( - current_plan->getCurrentDataStream().header.getColumnsWithTypeAndName(), - output_header.getColumnsWithTypeAndName(), - DB::ActionsDAG::MatchColumnsMode::Name); - auto post_project_step - = std::make_unique(current_plan->getCurrentDataStream(), std::move(post_project_actions_dag)); - post_project_step->setStepDescription("Window group limit: drop window function result column"); - steps.emplace_back(post_project_step.get()); - current_plan->addStep(std::move(post_project_step)); + auto window_group_limit_step = std::make_unique( + current_plan->getCurrentDataStream(), window_function_name, partition_fields, sort_fields, limit); + window_group_limit_step->setStepDescription("Window group limit"); + steps.emplace_back(window_group_limit_step.get()); + current_plan->addStep(std::move(window_group_limit_step)); - LOG_ERROR(getLogger("WindowGroupLimitRelParser"), "xxx output header: {}", current_plan->getCurrentDataStream().header.dumpStructure()); return std::move(current_plan); } -DB::WindowFrame WindowGroupLimitRelParser::buildWindowFrame(const String & function_name) +std::vector +WindowGroupLimitRelParser::parsePartitoinFields(const google::protobuf::RepeatedPtrField & expressions) { - // We only need first rows, so let the begin type is unbounded is OK - DB::WindowFrame frame; - if (function_name == FUNCTION_ROW_NUM) - { - frame.type = DB::WindowFrame::FrameType::ROWS; - frame.begin_type = DB::WindowFrame::BoundaryType::Unbounded; - frame.begin_offset = 0; - frame.begin_preceding = true; - frame.end_type = DB::WindowFrame::BoundaryType::Current; - frame.end_offset = 0; - frame.end_preceding = true; - } - else if (function_name == FUNCTION_RANK || function_name == FUNCTION_DENSE_RANK) - { - // rank and dense_rank can only work on range mode - frame.type = DB::WindowFrame::FrameType::RANGE; - frame.begin_type = DB::WindowFrame::BoundaryType::Unbounded; - frame.begin_offset = 0; - frame.begin_preceding = true; - frame.end_type = DB::WindowFrame::BoundaryType::Current; - frame.end_offset = 0; - frame.end_preceding = true; - } - else - { - throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Unknown function {} for window group limit", function_name); - } - - return frame; -} - -DB::WindowDescription WindowGroupLimitRelParser::buildWindowDescription(const substrait::WindowGroupLimitRel & win_rel_def) -{ - DB::WindowDescription win_desc; - win_desc.frame = buildWindowFrame(window_function_name); - win_desc.partition_by = parsePartitionBy(win_rel_def.partition_expressions()); - win_desc.order_by = SortRelParser::parseSortDescription(win_rel_def.sorts(), current_plan->getCurrentDataStream().header); - win_desc.full_sort_description = win_desc.partition_by; - win_desc.full_sort_description.insert(win_desc.full_sort_description.end(), win_desc.order_by.begin(), win_desc.order_by.end()); - - DB::WriteBufferFromOwnString ss; - ss << "partition by " << DB::dumpSortDescription(win_desc.partition_by); - ss << "order by " << DB::dumpSortDescription(win_desc.order_by); - ss << win_desc.frame.toString(); - win_desc.window_name = ss.str(); - - win_desc.window_functions.emplace_back(buildWindowFunctionDescription(window_function_name, static_cast(win_rel_def.limit()))); - - return win_desc; -} - -DB::SortDescription -WindowGroupLimitRelParser::parsePartitionBy(const google::protobuf::RepeatedPtrField & expressions) -{ - DB::Block header = current_plan->getCurrentDataStream().header; - DB::SortDescription sort_desc; + std::vector fields; for (const auto & expr : expressions) { if (expr.has_selection()) { - auto pos = expr.selection().direct_reference().struct_field().field(); - auto col_name = header.getByPosition(pos).name; - sort_desc.push_back(DB::SortColumnDescription(col_name, 1, 1)); + fields.push_back(static_cast(expr.selection().direct_reference().struct_field().field())); } else if (expr.has_literal()) { @@ -143,28 +77,33 @@ WindowGroupLimitRelParser::parsePartitionBy(const google::protobuf::RepeatedPtrF } else { - throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Unknow partition argument: {}", expr.DebugString()); + throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Unknow expression: {}", expr.DebugString()); } } - return sort_desc; + return fields; } -DB::WindowFunctionDescription WindowGroupLimitRelParser::buildWindowFunctionDescription(const String & function_name, size_t limit) +std::vector WindowGroupLimitRelParser::parseSortFields(const google::protobuf::RepeatedPtrField & sort_fields) { - DB::WindowFunctionDescription desc; - desc.column_name = function_name; - desc.function_node = nullptr; - DB::AggregateFunctionProperties func_properties; - DB::Names func_args; - DB::DataTypes func_args_types; - DB::Array func_params; - func_params.push_back(limit); - auto func_ptr = RelParser::getAggregateFunction(function_name, func_args_types, func_properties, func_params); - desc.argument_names = func_args; - desc.argument_types = func_args_types; - desc.aggregate_function = func_ptr; - return desc; + std::vector fields; + for (const auto sort_field : sort_fields) + { + if (sort_field.expr().has_literal()) + { + continue; + } + else if (sort_field.expr().has_selection()) + { + fields.push_back(static_cast(sort_field.expr().selection().direct_reference().struct_field().field())); + } + else + { + throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Unknown expression: {}", sort_field.expr().DebugString()); + } + } + return fields; } + void registerWindowGroupLimitRelParser(RelParserFactory & factory) { auto builder = [](SerializedPlanParser * plan_parser) { return std::make_shared(plan_parser); }; diff --git a/cpp-ch/local-engine/Parser/WindowGroupLimitRelParser.h b/cpp-ch/local-engine/Parser/WindowGroupLimitRelParser.h index 6b7d3bbf33e5..c9c503ed4745 100644 --- a/cpp-ch/local-engine/Parser/WindowGroupLimitRelParser.h +++ b/cpp-ch/local-engine/Parser/WindowGroupLimitRelParser.h @@ -46,12 +46,7 @@ class WindowGroupLimitRelParser : public RelParser DB::QueryPlanPtr current_plan; String window_function_name; - DB::WindowDescription buildWindowDescription(const substrait::WindowGroupLimitRel & win_rel_def); - /// There is only one type of window frame at present. - static DB::WindowFrame buildWindowFrame(const String & function_name); - - DB::SortDescription parsePartitionBy(const google::protobuf::RepeatedPtrField & expressions); - - static DB::WindowFunctionDescription buildWindowFunctionDescription(const String & function_name, size_t limit); + std::vector parsePartitoinFields(const google::protobuf::RepeatedPtrField & expressions); + std::vector parseSortFields(const google::protobuf::RepeatedPtrField & sort_fields); }; } diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala b/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala index a55926d76d12..dd4150806cfc 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala @@ -41,6 +41,7 @@ import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.sql.execution.joins.BuildSideRelation import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.execution.python.ArrowEvalPythonExec +import org.apache.spark.sql.execution.window._ import org.apache.spark.sql.hive.{HiveTableScanExecTransformer, HiveUDFTransformer} import org.apache.spark.sql.types.{DecimalType, LongType, NullType, StructType} import org.apache.spark.sql.vectorized.ColumnarBatch @@ -678,6 +679,15 @@ trait SparkPlanExecApi { } } + def genWindowGroupLimitTransformer( + partitionSpec: Seq[Expression], + orderSpec: Seq[SortOrder], + rankLikeFunction: Expression, + limit: Int, + mode: WindowGroupLimitMode, + child: SparkPlan): SparkPlan = + WindowGroupLimitExecTransformer(partitionSpec, orderSpec, rankLikeFunction, limit, mode, child) + def genHiveUDFTransformer( expr: Expression, attributeSeq: Seq[Attribute]): ExpressionTransformer = { diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/OffloadSingleNode.scala b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/OffloadSingleNode.scala index 6047789e6abe..5440481f8ddd 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/OffloadSingleNode.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/OffloadSingleNode.scala @@ -358,7 +358,7 @@ object OffloadOthers { val windowGroupLimitPlan = SparkShimLoader.getSparkShims .getWindowGroupLimitExecShim(plan) .asInstanceOf[WindowGroupLimitExecShim] - WindowGroupLimitExecTransformer( + BackendsApiManager.getSparkPlanExecApiInstance.genWindowGroupLimitTransformer( windowGroupLimitPlan.partitionSpec, windowGroupLimitPlan.orderSpec, windowGroupLimitPlan.rankLikeFunction,