diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
index c204c6d6b73f..e3d0a9bbba54 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
@@ -369,6 +369,7 @@ object CHBackendSettings extends BackendSettingsApi with Logging {
     )
   }

+  // If the partition keys have high cardinality, the aggregation-based method is slower than the window one.
  def enableConvertWindowGroupLimitToAggregate(): Boolean = {
    SparkEnv.get.conf.getBoolean(
      CHConf.runtimeConfig("enable_window_group_limit_to_aggregate"),
diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala
index 2214c79b20b4..573723e2c783 100644
--- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala
+++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala
@@ -3162,62 +3162,66 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends GlutenClickHouseTPCHAbstr
   }

   test("GLUTEN-7905 get topk of window by aggregate") {
-    def checkWindowGroupLimit(df: DataFrame): Unit = {
-      val expands = collectWithSubqueries(df.queryExecution.executedPlan) {
-        case e: ExpandExecTransformer
-            if (e.child.isInstanceOf[CHAggregateGroupLimitExecTransformer]) =>
-          e
+    withSQLConf((
+      "spark.gluten.sql.columnar.backend.ch.runtime_config.enable_window_group_limit_to_aggregate",
+      "true")) {
+      def checkWindowGroupLimit(df: DataFrame): Unit = {
+        val expands = collectWithSubqueries(df.queryExecution.executedPlan) {
+          case e: ExpandExecTransformer
+              if (e.child.isInstanceOf[CHAggregateGroupLimitExecTransformer]) =>
+            e
+        }
+        assert(expands.size == 1)
       }
-      assert(expands.size == 1)
+      spark.sql("create table test_win_top (a string, b int, c int) using parquet")
+      spark.sql("""
+                  |insert into test_win_top values
+                  |('a', 3, 3), ('a', 1, 5), ('a', 2, 2), ('a', null, null), ('a', null, 1),
+                  |('b', 1, 1), ('b', 2, 1),
+                  |('c', 2, 3)
+                  |""".stripMargin)
+      compareResultsAgainstVanillaSpark(
+        """
+          |select a, b, c, row_number() over (partition by a order by b desc nulls first) as r
+          |from test_win_top
+          |""".stripMargin,
+        true,
+        checkWindowGroupLimit
+      )
+      compareResultsAgainstVanillaSpark(
+        """
+          |select a, b, c, row_number() over (partition by a order by b desc, c nulls last) as r
+          |from test_win_top
+          |""".stripMargin,
+        true,
+        checkWindowGroupLimit
+      )
+      compareResultsAgainstVanillaSpark(
+        """
+          |select a, b, c, row_number() over (partition by a order by b asc nulls first, c) as r
+          |from test_win_top
+          |""".stripMargin,
+        true,
+        checkWindowGroupLimit
+      )
+      compareResultsAgainstVanillaSpark(
+        """
+          |select a, b, c, row_number() over (partition by a order by b asc nulls last) as r
+          |from test_win_top
+          |""".stripMargin,
+        true,
+        checkWindowGroupLimit
+      )
+      compareResultsAgainstVanillaSpark(
+        """
+          |select a, b, c, row_number() over (partition by a order by b , c) as r
+          |from test_win_top
+          |""".stripMargin,
+        true,
+        checkWindowGroupLimit
+      )
+      spark.sql("drop table if exists test_win_top")
     }
-    spark.sql("create table test_win_top (a string, b int, c int) using parquet")
-    spark.sql("""
-                |insert into test_win_top values
-                |('a', 3, 3), ('a', 1, 5), ('a', 2, 2), ('a', null, null), ('a', null, 1),
-                |('b', 1, 1), ('b', 2, 1),
-                |('c', 2, 3)
-                |""".stripMargin)
-    compareResultsAgainstVanillaSpark(
-      """
-        |select a, b, c, row_number() over (partition by a order by b desc nulls first) as r
-        |from test_win_top
-        |""".stripMargin,
-      true,
-      checkWindowGroupLimit
-    )
-    compareResultsAgainstVanillaSpark(
-      """
-        |select a, b, c, row_number() over (partition by a order by b desc, c nulls last) as r
-        |from test_win_top
-        |""".stripMargin,
-      true,
-      checkWindowGroupLimit
-    )
-    compareResultsAgainstVanillaSpark(
-      """
-        |select a, b, c, row_number() over (partition by a order by b asc nulls first) as r
-        |from test_win_top
-        |""".stripMargin,
-      true,
-      checkWindowGroupLimit
-    )
-    compareResultsAgainstVanillaSpark(
-      """
-        |select a, b, c, row_number() over (partition by a order by b asc nulls last) as r
-        |from test_win_top
-        |""".stripMargin,
-      true,
-      checkWindowGroupLimit
-    )
-    compareResultsAgainstVanillaSpark(
-      """
-        |select a, b, c, row_number() over (partition by a order by b , c) as r
-        |from test_win_top
-        |""".stripMargin,
-      true,
-      checkWindowGroupLimit
-    )
-    spark.sql("drop table if exists test_win_top")
   }
diff --git a/cpp-ch/local-engine/AggregateFunctions/GroupLimitFunctions.cpp b/cpp-ch/local-engine/AggregateFunctions/GroupLimitFunctions.cpp
index 2884b687e335..137ae8a54489 100644
--- a/cpp-ch/local-engine/AggregateFunctions/GroupLimitFunctions.cpp
+++ b/cpp-ch/local-engine/AggregateFunctions/GroupLimitFunctions.cpp
@@ -41,6 +41,7 @@
 #include
 #include
+#include "base/defines.h"

 namespace DB::ErrorCodes
 {
@@ -72,7 +73,6 @@ struct RowNumGroupArraySortedData
         const auto & pos = sort_order.pos;
         const auto & asc = sort_order.direction;
         const auto & nulls_first = sort_order.nulls_direction;
-        LOG_ERROR(getLogger("GroupLimitFunction"), "xxx pos: {} tuple size: {} {}", pos, rhs.size(), lhs.size());
         bool l_is_null = lhs[pos].isNull();
         bool r_is_null = rhs[pos].isNull();
         if (l_is_null && r_is_null)
@@ -120,25 +120,17 @@ struct RowNumGroupArraySortedData
         values[current_index] = current;
     }

-    ALWAYS_INLINE void addElement(const Data & data, const SortOrderFields & sort_orders, size_t max_elements)
+    // Take the tuple by (non-const) rvalue reference so it can actually be moved into the heap.
+    ALWAYS_INLINE void addElement(Data && data, const SortOrderFields & sort_orders, size_t max_elements)
     {
         if (values.size() >= max_elements)
         {
-            LOG_ERROR(
-                getLogger("GroupLimitFunction"),
-                "xxxx values size: {}, limit: {}, tuple size: {} {}",
-                values.size(),
-                max_elements,
-                data.size(),
-                values[0].size());
             if (!compare(data, values[0], sort_orders))
                 return;
-            values[0] = data;
+            values[0] = std::move(data);
             heapReplaceTop(sort_orders);
             return;
         }
-        values.push_back(data);
-        LOG_ERROR(getLogger("GroupLimitFunction"), "add new element: {} {}", values.size(), values.back().size());
+        values.emplace_back(std::move(data));
         auto cmp = [&sort_orders](const Data & a, const Data & b) { return compare(a, b, sort_orders); };
         std::push_heap(values.begin(), values.end(), cmp);
     }
@@ -190,7 +182,7 @@ class RowNumGroupArraySorted final : public DB::IAggregateFunctionDataHelper
         : DB::IAggregateFunctionDataHelper<RowNumGroupArraySortedData, RowNumGroupArraySorted>(
-        {data_type}, parameters_, getRowNumReultDataType(data_type))
+          {data_type}, parameters_, getRowNumReultDataType(data_type))
     {
         if (parameters_.size() != 2)
             throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "{} needs two parameters: limit and order clause", getName());
@@ -212,23 +204,14 @@ class RowNumGroupArraySorted final : public DB::IAggregateFunctionDataHelper
     void add(DB::AggregateDataPtr __restrict place, const DB::IColumn ** columns, size_t row_num, DB::Arena * /*arena*/) const override
     {
         auto & data = this->data(place);
         DB::Tuple data_tuple = (*columns[0])[row_num].safeGet<DB::Tuple>();
-        // const DB::Tuple & data_tuple = *(static_cast(&((*columns[0])[row_num])));
-        LOG_ERROR(
-            getLogger("GroupLimitFunction"),
-            "xxx col len: {}, row num: {}, tuple size: {}, type: {}",
-            columns[0]->size(),
-            row_num,
-            data_tuple.size(),
-            (*columns[0])[row_num].getType());
-        ;
-        this->data(place).addElement(data_tuple, sort_order_fields, limit);
+        this->data(place).addElement(std::move(data_tuple), sort_order_fields, limit);
     }

     void merge(DB::AggregateDataPtr __restrict place, DB::ConstAggregateDataPtr rhs, DB::Arena * /*arena*/) const override
     {
         auto & rhs_values = this->data(rhs).values;
         for (auto & rhs_element : rhs_values)
-            this->data(place).addElement(rhs_element, sort_order_fields, limit);
+            // Copy explicitly: rhs must stay intact, so don't move out of its values.
+            this->data(place).addElement(Data(rhs_element), sort_order_fields, limit);
     }

     void serialize(DB::ConstAggregateDataPtr __restrict place, DB::WriteBuffer & buf, std::optional<size_t> /* version */) const override
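The heap maintenance above is the core of `rowNumGroupArraySorted`: each partition keeps at most `limit` rows in a heap ordered by the sort key, so an incoming row only ever displaces the current worst kept row. A minimal standalone sketch of the same bounded-heap pattern with plain `int`s (illustrative only; the PR's version orders ClickHouse `Field` tuples with per-column sort directions and replaces the top in place via `heapReplaceTop`):

```cpp
#include <algorithm>
#include <iostream>
#include <vector>

// Keep the k smallest values seen so far. values is a max-heap, so
// values.front() is the current worst (largest) kept element.
void addElement(std::vector<int> & values, int candidate, size_t k)
{
    if (values.size() >= k)
    {
        if (candidate >= values.front())
            return; // not better than the worst kept element
        std::pop_heap(values.begin(), values.end());
        values.back() = candidate;
        std::push_heap(values.begin(), values.end());
        return;
    }
    values.push_back(candidate);
    std::push_heap(values.begin(), values.end());
}

int main()
{
    std::vector<int> topk;
    for (int v : {5, 1, 4, 2, 8, 3})
        addElement(topk, v, /*k=*/3);
    std::sort(topk.begin(), topk.end());
    for (int v : topk)
        std::cout << v << ' '; // prints: 1 2 3
}
```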
diff --git a/cpp-ch/local-engine/Common/AggregateUtil.cpp b/cpp-ch/local-engine/Common/AggregateUtil.cpp
index 2290747fa158..0707d18aa01b 100644
--- a/cpp-ch/local-engine/Common/AggregateUtil.cpp
+++ b/cpp-ch/local-engine/Common/AggregateUtil.cpp
@@ -47,7 +47,6 @@ extern const SettingsBool enable_memory_bound_merging_of_aggregation_results;
 extern const SettingsUInt64 aggregation_in_order_max_block_bytes;
 extern const SettingsUInt64 group_by_two_level_threshold;
 extern const SettingsFloat min_hit_rate_to_use_consecutive_keys_optimization;
-extern const SettingsMaxThreads max_threads;
 extern const SettingsUInt64 max_block_size;
 }

diff --git a/cpp-ch/local-engine/Common/ArrayJoinHelper.cpp b/cpp-ch/local-engine/Common/ArrayJoinHelper.cpp
index de32747690a0..acefad0aea2a 100644
--- a/cpp-ch/local-engine/Common/ArrayJoinHelper.cpp
+++ b/cpp-ch/local-engine/Common/ArrayJoinHelper.cpp
@@ -150,6 +150,21 @@ addArrayJoinStep(DB::ContextPtr context, DB::QueryPlan & plan, const DB::Actions
         steps.emplace_back(array_join_step.get());
         plan.addStep(std::move(array_join_step));
         // LOG_DEBUG(logger, "plan2:{}", PlanUtil::explainPlan(*query_plan));
+
+        /// Post-projection after the array join (optional)
+        if (!ignore_actions_dag(splitted_actions_dags.after_array_join))
+        {
+            auto step_after_array_join
+                = std::make_unique<DB::ExpressionStep>(plan.getCurrentHeader(), std::move(splitted_actions_dags.after_array_join));
+            step_after_array_join->setStepDescription("Post-projection In Generate");
+            steps.emplace_back(step_after_array_join.get());
+            plan.addStep(std::move(step_after_array_join));
+            // LOG_DEBUG(logger, "plan3:{}", PlanUtil::explainPlan(*query_plan));
+        }
+    }
+    else
+    {
+        throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Expect array join node in actions_dag");
     }

     return steps;
diff --git a/cpp-ch/local-engine/Common/GlutenConfig.cpp b/cpp-ch/local-engine/Common/GlutenConfig.cpp
index 0cefbc383977..ce15a12f921a 100644
--- a/cpp-ch/local-engine/Common/GlutenConfig.cpp
+++ b/cpp-ch/local-engine/Common/GlutenConfig.cpp
@@ -141,4 +141,13 @@ MergeTreeCacheConfig MergeTreeCacheConfig::loadFromContext(const DB::ContextPtr
     config.enable_data_prefetch = context->getConfigRef().getBool(ENABLE_DATA_PREFETCH, config.enable_data_prefetch);
     return config;
 }
-}
\ No newline at end of file
+
+WindowConfig WindowConfig::loadFromContext(const DB::ContextPtr & context)
+{
+    WindowConfig config;
+    config.aggregate_topk_sample_rows = context->getConfigRef().getUInt64(WINDOW_AGGREGATE_TOPK_SAMPLE_ROWS, 5000);
+    config.aggregate_topk_high_cardinality_threshold
+        = context->getConfigRef().getDouble(WINDOW_AGGREGATE_TOPK_HIGH_CARDINALITY_THRESHOLD, 0.6);
+    return 
config; +} +} diff --git a/cpp-ch/local-engine/Common/GlutenConfig.h b/cpp-ch/local-engine/Common/GlutenConfig.h index 8af83329b6b7..85839b70ecd2 100644 --- a/cpp-ch/local-engine/Common/GlutenConfig.h +++ b/cpp-ch/local-engine/Common/GlutenConfig.h @@ -56,9 +56,12 @@ struct GraceMergingAggregateConfig { inline static const String MAX_GRACE_AGGREGATE_MERGING_BUCKETS = "max_grace_aggregate_merging_buckets"; inline static const String THROW_ON_OVERFLOW_GRACE_AGGREGATE_MERGING_BUCKETS = "throw_on_overflow_grace_aggregate_merging_buckets"; - inline static const String AGGREGATED_KEYS_BEFORE_EXTEND_GRACE_AGGREGATE_MERGING_BUCKETS = "aggregated_keys_before_extend_grace_aggregate_merging_buckets"; - inline static const String MAX_PENDING_FLUSH_BLOCKS_PER_GRACE_AGGREGATE_MERGING_BUCKET = "max_pending_flush_blocks_per_grace_aggregate_merging_bucket"; - inline static const String MAX_ALLOWED_MEMORY_USAGE_RATIO_FOR_AGGREGATE_MERGING = "max_allowed_memory_usage_ratio_for_aggregate_merging"; + inline static const String AGGREGATED_KEYS_BEFORE_EXTEND_GRACE_AGGREGATE_MERGING_BUCKETS + = "aggregated_keys_before_extend_grace_aggregate_merging_buckets"; + inline static const String MAX_PENDING_FLUSH_BLOCKS_PER_GRACE_AGGREGATE_MERGING_BUCKET + = "max_pending_flush_blocks_per_grace_aggregate_merging_bucket"; + inline static const String MAX_ALLOWED_MEMORY_USAGE_RATIO_FOR_AGGREGATE_MERGING + = "max_allowed_memory_usage_ratio_for_aggregate_merging"; size_t max_grace_aggregate_merging_buckets = 32; bool throw_on_overflow_grace_aggregate_merging_buckets = false; @@ -73,7 +76,8 @@ struct StreamingAggregateConfig { inline static const String AGGREGATED_KEYS_BEFORE_STREAMING_AGGREGATING_EVICT = "aggregated_keys_before_streaming_aggregating_evict"; inline static const String MAX_MEMORY_USAGE_RATIO_FOR_STREAMING_AGGREGATING = "max_memory_usage_ratio_for_streaming_aggregating"; - inline static const String HIGH_CARDINALITY_THRESHOLD_FOR_STREAMING_AGGREGATING = "high_cardinality_threshold_for_streaming_aggregating"; + inline static const String HIGH_CARDINALITY_THRESHOLD_FOR_STREAMING_AGGREGATING + = "high_cardinality_threshold_for_streaming_aggregating"; inline static const String ENABLE_STREAMING_AGGREGATING = "enable_streaming_aggregating"; size_t aggregated_keys_before_streaming_aggregating_evict = 1024; @@ -154,6 +158,16 @@ struct MergeTreeCacheConfig static MergeTreeCacheConfig loadFromContext(const DB::ContextPtr & context); }; +struct WindowConfig +{ +public: + inline static const String WINDOW_AGGREGATE_TOPK_SAMPLE_ROWS = "window.aggregate_topk_sample_rows"; + inline static const String WINDOW_AGGREGATE_TOPK_HIGH_CARDINALITY_THRESHOLD = "window.aggregate_topk_high_cardinality_threshold"; + size_t aggregate_topk_sample_rows = 5000; + double aggregate_topk_high_cardinality_threshold = 0.6; + static WindowConfig loadFromContext(const DB::ContextPtr & context); +}; + namespace PathConfig { inline constexpr const char * USE_CURRENT_DIRECTORY_AS_TMP = "use_current_directory_as_tmp"; diff --git a/cpp-ch/local-engine/Operator/BranchStep.cpp b/cpp-ch/local-engine/Operator/BranchStep.cpp new file mode 100644 index 000000000000..faf05ebeaeb3 --- /dev/null +++ b/cpp-ch/local-engine/Operator/BranchStep.cpp @@ -0,0 +1,326 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "BranchStep.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace local_engine +{ + +class BranchOutputTransform : public DB::IProcessor +{ +public: + using Status = DB::IProcessor::Status; + BranchOutputTransform(const DB::Block & header_) : DB::IProcessor({header_}, {header_}) { } + ~BranchOutputTransform() override = default; + + String getName() const override { return "BranchOutputTransform"; } + + Status prepare() override; + void work() override; + +private: + bool has_output = false; + DB::Chunk output_chunk; + bool has_input = false; + DB::Chunk input_chunk; +}; + +BranchOutputTransform::Status BranchOutputTransform::prepare() +{ + auto & output = outputs.front(); + auto & input = inputs.front(); + if (output.isFinished()) + { + input.close(); + return Status::Finished; + } + if (has_output) + { + if (output.canPush()) + { + output.push(std::move(output_chunk)); + has_output = false; + } + return Status::PortFull; + } + if (has_input) + return Status::Ready; + if (input.isFinished()) + { + output.finish(); + return Status::Finished; + } + input.setNeeded(); + if (!input.hasData()) + return Status::NeedData; + input_chunk = input.pull(true); + has_input = true; + return Status::Ready; +} + +void BranchOutputTransform::work() +{ + if (has_input) + { + output_chunk = std::move(input_chunk); + has_output = true; + has_input = false; + } +} + +class BranchHookSource : public DB::IProcessor +{ +public: + using Status = DB::IProcessor::Status; + BranchHookSource(const DB::Block & header_) : DB::IProcessor({}, {header_}) { inner_inputs.emplace_back(header_, this); } + ~BranchHookSource() override = default; + + String getName() const override { return "BranchHookSource"; } + + Status prepare() override; + void work() override; + void enableInputs() { inputs.swap(inner_inputs); } + +private: + DB::InputPorts inner_inputs; + bool has_output = false; + DB::Chunk output_chunk; + bool has_input = false; + DB::Chunk input_chunk; +}; + +BranchHookSource::Status BranchHookSource::prepare() +{ + auto & output = outputs.front(); + auto & input = inputs.front(); + if (output.isFinished()) + { + input.close(); + return Status::Finished; + } + if (has_output) + { + if (output.canPush()) + { + output.push(std::move(output_chunk)); + has_output = false; + } + return Status::PortFull; + } + if (has_input) + return Status::Ready; + if (input.isFinished()) + { + output.finish(); + return Status::Finished; + } + input.setNeeded(); + if (!input.hasData()) + return Status::NeedData; + input_chunk = input.pull(true); + has_input = true; + return Status::Ready; +} + +void BranchHookSource::work() +{ + if (has_input) + { + output_chunk = std::move(input_chunk); + has_output = true; + has_input = false; + } +} + + +static DB::ITransformingStep::Traits getTraits() +{ + return DB::ITransformingStep::Traits{ + { + 
.returns_single_stream = true,
+            .preserves_number_of_streams = false,
+            .preserves_sorting = false,
+        },
+        {
+            .preserves_number_of_rows = false,
+        }};
+}
+
+class ResizeStep : public DB::ITransformingStep
+{
+public:
+    explicit ResizeStep(const DB::Block & header_, size_t num_streams_)
+        : DB::ITransformingStep(header_, header_, getTraits()), num_streams(num_streams_)
+    {
+    }
+    ~ResizeStep() override = default;
+
+    String getName() const override { return "ResizeStep"; }
+
+    void transformPipeline(DB::QueryPipelineBuilder & pipeline, const DB::BuildQueryPipelineSettings &) override
+    {
+        pipeline.resize(num_streams);
+    }
+    void describePipeline(DB::IQueryPlanStep::FormatSettings & settings) const override
+    {
+        if (!processors.empty())
+            DB::IQueryPlanStep::describePipeline(processors, settings);
+    }
+
+private:
+    size_t num_streams;
+    void updateOutputHeader() override {};
+};
+
+DB::QueryPlanPtr BranchStepHelper::createSubPlan(const DB::Block & header, size_t num_streams)
+{
+    auto source = std::make_unique<DB::ReadFromPreparedSource>(DB::Pipe(std::make_shared<BranchHookSource>(header)));
+    source->setStepDescription("Hook node connected to one branch output");
+    auto plan = std::make_unique<DB::QueryPlan>();
+    plan->addStep(std::move(source));
+
+    if (num_streams > 1)
+    {
+        auto resize_step = std::make_unique<ResizeStep>(plan->getCurrentHeader(), num_streams);
+        plan->addStep(std::move(resize_step));
+    }
+    return std::move(plan);
+}
+
+StaticBranchStep::StaticBranchStep(
+    DB::ContextPtr context_, const DB::Block & header_, size_t branches_, size_t sample_rows_, BranchSelector selector_)
+    : DB::ITransformingStep(header_, header_, getTraits())
+    , context(context_)
+    , header(header_)
+    , max_sample_rows(sample_rows_)
+    , branches(branches_)
+    , selector(selector_)
+{
+}
+
+void StaticBranchStep::transformPipeline(DB::QueryPipelineBuilder & pipeline, const DB::BuildQueryPipelineSettings & settings)
+{
+    auto build_transform = [&](DB::OutputPortRawPtrs child_outputs) -> DB::Processors
+    {
+        DB::Processors new_processors;
+        for (auto & output : child_outputs)
+        {
+            if (!output)
+                throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Output port is null");
+            auto branch_transform = std::make_shared<StaticBranchTransform>(header, max_sample_rows, branches, selector);
+            DB::connect(*output, branch_transform->getInputs().front());
+            new_processors.push_back(branch_transform);
+
+            for (auto & branch_output : branch_transform->getOutputs())
+            {
+                auto branch_processor = std::make_shared<BranchOutputTransform>(header);
+                DB::connect(branch_output, branch_processor->getInputs().front());
+                new_processors.push_back(branch_processor);
+            }
+        }
+        return new_processors;
+    };
+    pipeline.resize(1);
+    pipeline.transform(build_transform);
+}
+
+void StaticBranchStep::describePipeline(DB::IQueryPlanStep::FormatSettings & settings) const
+{
+    if (!processors.empty())
+        DB::IQueryPlanStep::describePipeline(processors, settings);
+}
+
+void StaticBranchStep::updateOutputHeader()
+{
+}
+
+UniteBranchesStep::UniteBranchesStep(const DB::Block & header_, std::vector<DB::QueryPlanPtr> && branch_plans_, size_t num_streams_)
+    : DB::ITransformingStep(header_, branch_plans_[0]->getCurrentHeader(), getTraits()), header(header_)
+{
+    branch_plans.swap(branch_plans_);
+    num_streams = num_streams_;
+}
+
+void UniteBranchesStep::transformPipeline(DB::QueryPipelineBuilder & pipeline, const DB::BuildQueryPipelineSettings &)
+{
+    auto add_transform = [&](DB::OutputPortRawPtrs child_outputs) -> DB::Processors
+    {
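+        // Each branch sub-plan was created around a BranchHookSource placeholder by
+        // BranchStepHelper::createSubPlan. Below, every pre-built sub-plan is compiled into
+        // processors, its hook source's hidden input port is enabled, and that port is
+        // connected to the corresponding output of the branch transform.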
+        DB::Processors new_processors;
+        size_t branch_index = 0;
+        for (auto output : child_outputs)
+        {
+            auto & branch_plan = branch_plans[branch_index];
+            DB::QueryPlanOptimizationSettings optimization_settings;
+            DB::BuildQueryPipelineSettings build_settings;
+            DB::QueryPlanResourceHolder resource_holder;
+
+            auto pipeline_builder = branch_plan->buildQueryPipeline(optimization_settings, build_settings);
+            auto pipe = DB::QueryPipelineBuilder::getPipe(std::move(*pipeline_builder), resource_holder);
+            DB::ProcessorPtr source_node = nullptr;
+            auto processors = DB::Pipe::detachProcessors(std::move(pipe));
+            for (auto processor : processors)
+            {
+                if (auto * source = typeid_cast<BranchHookSource *>(processor.get()))
+                {
+                    if (source->getInputs().empty())
+                    {
+                        if (source_node)
+                            throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Found multiple source nodes in one branch plan");
+                        source->enableInputs();
+                        source_node = processor;
+                    }
+                }
+            }
+            if (!source_node)
+                throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Cannot find source node in branch plan");
+            if (source_node->getInputs().empty())
+                throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Source node has no input");
+            DB::connect(*output, source_node->getInputs().front());
+            new_processors.insert(new_processors.end(), processors.begin(), processors.end());
+            branch_index++;
+        }
+        return new_processors;
+    };
+    pipeline.transform(add_transform);
+    pipeline.resize(1);
+    if (num_streams > 1)
+        pipeline.resize(num_streams);
+}
+
+void UniteBranchesStep::describePipeline(DB::IQueryPlanStep::FormatSettings & settings) const
+{
+    if (!processors.empty())
+        DB::IQueryPlanStep::describePipeline(processors, settings);
+}
+}
diff --git a/cpp-ch/local-engine/Operator/BranchStep.h b/cpp-ch/local-engine/Operator/BranchStep.h
new file mode 100644
index 000000000000..17056aa9b5a6
--- /dev/null
+++ b/cpp-ch/local-engine/Operator/BranchStep.h
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include "Processors/Port.h"
+#include "Processors/QueryPlan/QueryPlan.h"
+
+namespace local_engine
+{
+
+class BranchStepHelper
+{
+public:
+    // Creates a new query plan skeleton used to build one branch's sub-plan.
+    static DB::QueryPlanPtr createSubPlan(const DB::Block & header, size_t num_streams);
+};
+
+// Used to fork the query plan into several branches.
+class StaticBranchStep : public DB::ITransformingStep
+{
+public:
+    using BranchSelector = std::function<size_t(const std::list<DB::Chunk> &)>;
+    explicit StaticBranchStep(
+        DB::ContextPtr context_, const DB::Block & header, size_t branches, size_t sample_rows, BranchSelector selector);
+    ~StaticBranchStep() override = default;
+
+    String getName() const override { return "StaticBranchStep"; }
+
+    // Note: this resizes the pipeline down to a single stream; resize it again afterwards if needed.
+    void transformPipeline(DB::QueryPipelineBuilder & pipeline, const DB::BuildQueryPipelineSettings & settings) override;
+    void describePipeline(DB::IQueryPlanStep::FormatSettings & settings) const override;
+
+protected:
+    void updateOutputHeader() override;
+
+private:
+    DB::ContextPtr context;
+    DB::Block header;
+    size_t max_sample_rows;
+    size_t branches;
+    BranchSelector selector;
+};
+
+class UniteBranchesStep : public DB::ITransformingStep
+{
+public:
+    explicit UniteBranchesStep(const DB::Block & header_, std::vector<DB::QueryPlanPtr> && branch_plans_, size_t num_streams_);
+    ~UniteBranchesStep() override = default;
+
+    String getName() const override { return "UniteBranchesStep"; }
+
+    void transformPipeline(DB::QueryPipelineBuilder & pipelines, const DB::BuildQueryPipelineSettings &) override;
+    void describePipeline(DB::IQueryPlanStep::FormatSettings & settings) const override;
+
+private:
+    DB::Block header;
+    std::vector<DB::QueryPlanPtr> branch_plans;
+    size_t num_streams;
+
+    void updateOutputHeader() override { output_header = header; };
+};
+
+}
diff --git a/cpp-ch/local-engine/Operator/BranchTransform.cpp b/cpp-ch/local-engine/Operator/BranchTransform.cpp
new file mode 100644
index 000000000000..f923f4ac4b41
--- /dev/null
+++ b/cpp-ch/local-engine/Operator/BranchTransform.cpp
@@ -0,0 +1,155 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +#include "BranchTransform.h" +#include +#include +#include +#include +#include + +namespace DB::ErrorCodes +{ +extern const int LOGICAL_ERROR; +} + +namespace local_engine +{ +static DB::OutputPorts buildOutputPorts(const DB::Block & header, size_t branches) +{ + DB::OutputPorts output_ports; + for (size_t i = 0; i < branches; ++i) + output_ports.emplace_back(header); + return output_ports; +} +StaticBranchTransform::StaticBranchTransform(const DB::Block & header_, size_t sample_rows_, size_t branches_, BranchSelector selector_) + : DB::IProcessor({header_}, buildOutputPorts(header_, branches_)), max_sample_rows(sample_rows_), selector(selector_) +{ +} + +static bool existFinishedOutput(const DB::OutputPorts & output_ports) +{ + for (const auto & output_port : output_ports) + if (output_port.isFinished()) + return true; + return false; +} + +StaticBranchTransform::Status StaticBranchTransform::prepare() +{ + auto & input = inputs.front(); + if ((selected_output_port && selected_output_port->isFinished()) || (!selected_output_port && existFinishedOutput(outputs))) + { + input.close(); + return Status::Finished; + } + + if (has_output) + { + assert(selected_output_port != nullptr); + if (selected_output_port->canPush()) + { + selected_output_port->push(std::move(output_chunk)); + has_output = false; + } + return Status::PortFull; + } + + if (has_input || (selected_output_port && !sample_chunks.empty())) + { + // to clear the pending chunks + return Status::Ready; + } + + if (input.isFinished()) + { + if (!sample_chunks.empty()) + { + // to clear the pending chunks + return Status::Ready; + } + else + { + if (selected_output_port) + selected_output_port->finish(); + else + for (auto & output_port : outputs) + output_port.finish(); + return Status::Finished; + } + } + + input.setNeeded(); + if (!input.hasData()) + return Status::NeedData; + input_chunk = input.pull(true); + has_input = true; + return Status::Ready; +} + +void StaticBranchTransform::work() +{ + if (selected_output_port) + { + if (!sample_chunks.empty()) + { + assert(!has_input); + has_output = true; + output_chunk.swap(sample_chunks.front()); + sample_chunks.pop_front(); + } + else + { + assert(has_input); + has_input = false; + has_output = true; + output_chunk.swap(input_chunk); + } + } + else if (has_input) + { + sample_rows += input_chunk.getNumRows(); + sample_chunks.emplace_back(std::move(input_chunk)); + if (sample_rows >= max_sample_rows) + setupOutputPort(); + has_input = false; + } + else if (!sample_chunks.empty()) + { + if (!selected_output_port) + setupOutputPort(); + output_chunk.swap(sample_chunks.front()); + sample_chunks.pop_front(); + has_output = true; + } +} + +void StaticBranchTransform::setupOutputPort() +{ + size_t branch_index = selector(sample_chunks); + LOG_DEBUG(getLogger("StaticBranchTransform"), "Select output port: {}", branch_index); + if (branch_index >= outputs.size()) + throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Branch index {} is out of range(0, {})", branch_index, outputs.size()); + auto it = outputs.begin(); + std::advance(it, branch_index); + selected_output_port = &(*it); + // close other output ports + for (auto oit = outputs.begin(); oit != outputs.end(); ++oit) + if (oit != it) + oit->finish(); +} +} // namespace local_engine diff --git a/cpp-ch/local-engine/Operator/BranchTransform.h b/cpp-ch/local-engine/Operator/BranchTransform.h new file mode 100644 index 000000000000..f5284b5ae968 --- /dev/null +++ b/cpp-ch/local-engine/Operator/BranchTransform.h @@ -0,0 +1,56 @@ +/* 
diff --git a/cpp-ch/local-engine/Operator/BranchTransform.h b/cpp-ch/local-engine/Operator/BranchTransform.h
new file mode 100644
index 000000000000..f5284b5ae968
--- /dev/null
+++ b/cpp-ch/local-engine/Operator/BranchTransform.h
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+#include
+#include
+#include
+
+namespace local_engine
+{
+
+// Designed for adaptive execution. The transform has multiple output ports, one per execution
+// branch. It accepts a branch selector which analyzes the sampled input data and picks one of
+// the output ports as the single final output; all other output ports are closed.
+// Once selected, the output port cannot be changed.
+class StaticBranchTransform : public DB::IProcessor
+{
+public:
+    using BranchSelector = std::function<size_t(const std::list<DB::Chunk> &)>;
+    using Status = DB::IProcessor::Status;
+    StaticBranchTransform(const DB::Block & header_, size_t sample_rows_, size_t branches_, BranchSelector selector_);
+
+    String getName() const override { return "StaticBranchTransform"; }
+
+    Status prepare() override;
+    void work() override;
+
+private:
+    size_t max_sample_rows;
+    BranchSelector selector;
+    DB::OutputPort * selected_output_port = nullptr;
+    std::list<DB::Chunk> sample_chunks;
+    size_t sample_rows = 0;
+    bool has_input = false;
+    bool has_output = false;
+    DB::Chunk input_chunk;
+    DB::Chunk output_chunk;
+
+    void setupOutputPort();
+};
+
+}
diff --git a/cpp-ch/local-engine/Operator/WindowGroupLimitStep.cpp b/cpp-ch/local-engine/Operator/WindowGroupLimitStep.cpp
index f25e3f22ac65..81cc44cf9c08 100644
--- a/cpp-ch/local-engine/Operator/WindowGroupLimitStep.cpp
+++ b/cpp-ch/local-engine/Operator/WindowGroupLimitStep.cpp
@@ -17,6 +17,9 @@

 #include "WindowGroupLimitStep.h"

+#include
+#include
+#include
 #include
 #include
 #include
@@ -38,18 +41,34 @@ enum class WindowGroupLimitFunction
 };

+static DB::Block buildOutputHeader(const DB::Block & input_header, bool need_output_rank_values)
+{
+    if (!need_output_rank_values)
+        return input_header;
+    DB::Block output_header = input_header;
+    auto type = std::make_shared<DB::DataTypeInt32>();
+    auto col = type->createColumn();
+    output_header.insert(DB::ColumnWithTypeAndName(std::move(col), type, "rank_value"));
+    return output_header;
+}
+
 template <WindowGroupLimitFunction function>
 class WindowGroupLimitTransform : 
public DB::IProcessor void work() override { if (!has_input) [[unlikely]] - { return; - } DB::Block block = header.cloneWithColumns(input_chunk.getColumns()); size_t partition_start_row = 0; size_t chunk_rows = input_chunk.getNumRows(); @@ -119,6 +136,11 @@ class WindowGroupLimitTransform : public DB::IProcessor if (!output_columns.empty() && output_columns[0]->size() > 0) { auto rows = output_columns[0]->size(); + if (rank_value_column) + { + output_columns.push_back(std::move(rank_value_column)); + rank_value_column.reset(); + } output_chunk = DB::Chunk(std::move(output_columns), rows); output_columns.clear(); has_output = true; @@ -134,11 +156,13 @@ class WindowGroupLimitTransform : public DB::IProcessor std::vector sort_columns; // Limitations for each partition. size_t limit = 0; + bool need_output_rank_values; bool has_input = false; DB::Chunk input_chunk; bool has_output = false; DB::MutableColumns output_columns; + DB::MutableColumnPtr rank_value_column = nullptr; DB::Chunk output_chunk; // We don't have window frame here. in fact all of frame are (unbounded preceding, current row] @@ -152,6 +176,12 @@ class WindowGroupLimitTransform : public DB::IProcessor DB::Columns peer_group_start_row_columns; + void tryCreateRankValueColumn() + { + if (!rank_value_column) + rank_value_column = DB::DataTypeInt32().createColumn(); + } + size_t advanceNextPartition(const DB::Chunk & chunk, size_t start_offset) { if (partition_start_row_columns.empty()) @@ -159,12 +189,8 @@ class WindowGroupLimitTransform : public DB::IProcessor size_t max_row = chunk.getNumRows(); for (size_t i = start_offset; i < max_row; ++i) - { if (!isRowEqual(partition_columns, partition_start_row_columns, 0, chunk.getColumns(), i)) - { return i; - } - } return max_row; } @@ -199,7 +225,6 @@ class WindowGroupLimitTransform : public DB::IProcessor if (current_row_rank_value > limit) return; - size_t chunk_rows = chunk.getNumRows(); auto has_peer_group_ended = [&](size_t offset, size_t partition_end_offset, size_t chunk_rows_) { return offset < partition_end_offset || end_offset < chunk_rows_; }; @@ -240,7 +265,14 @@ class WindowGroupLimitTransform : public DB::IProcessor size_t rows = end_offset - start_offset; size_t limit_remained = limit - current_row_rank_value + 1; rows = rows > limit_remained ? 
limit_remained : rows; + if (need_output_rank_values) + { + tryCreateRankValueColumn(); + for (Int32 i = 0; i < static_cast(rows); ++i) + typeid_cast *>(rank_value_column.get())->insertValue(current_row_rank_value + i); + } insertResultValue(chunk, start_offset, rows); + current_row_rank_value += rows; } else @@ -249,8 +281,13 @@ class WindowGroupLimitTransform : public DB::IProcessor while (peer_group_start_offset < end_offset && current_row_rank_value <= limit) { auto next_peer_group_start_offset = advanceNextPeerGroup(chunk, peer_group_start_offset, end_offset); - - insertResultValue(chunk, peer_group_start_offset, next_peer_group_start_offset - peer_group_start_offset); + size_t group_rows = next_peer_group_start_offset - peer_group_start_offset; + if (need_output_rank_values) + { + tryCreateRankValueColumn(); + rank_value_column->insertMany(current_row_rank_value, group_rows); + } + insertResultValue(chunk, peer_group_start_offset, group_rows); try_end_peer_group(peer_group_start_offset, next_peer_group_start_offset, end_offset, chunk_rows); peer_group_start_offset = next_peer_group_start_offset; } @@ -261,12 +298,8 @@ class WindowGroupLimitTransform : public DB::IProcessor if (!rows) return; if (output_columns.empty()) - { for (const auto & col : chunk.getColumns()) - { output_columns.push_back(col->cloneEmpty()); - } - } size_t i = 0; for (const auto & col : chunk.getColumns()) { @@ -279,12 +312,8 @@ class WindowGroupLimitTransform : public DB::IProcessor if (peer_group_start_row_columns.empty()) peer_group_start_row_columns = extractOneRowColumns(chunk, start_offset); for (size_t i = start_offset; i < partition_end_offset; ++i) - { if (!isRowEqual(sort_columns, peer_group_start_row_columns, 0, chunk.getColumns(), i)) - { return i; - } - } return partition_end_offset; } }; @@ -306,12 +335,14 @@ WindowGroupLimitStep::WindowGroupLimitStep( const String & function_name_, const std::vector & partition_columns_, const std::vector & sort_columns_, - size_t limit_) - : DB::ITransformingStep(input_header_, input_header_, getTraits()) + size_t limit_, + bool need_output_rank_values_) + : DB::ITransformingStep(input_header_, buildOutputHeader(input_header_, need_output_rank_values_), getTraits()) , function_name(function_name_) , partition_columns(partition_columns_) , sort_columns(sort_columns_) , limit(limit_) + , need_output_rank_values(need_output_rank_values_) { } @@ -335,7 +366,7 @@ void WindowGroupLimitStep::transformPipeline(DB::QueryPipelineBuilder & pipeline [&](const DB::Block & header) { return std::make_shared>( - header, partition_columns, sort_columns, limit); + header, partition_columns, sort_columns, limit, need_output_rank_values); }); } else if (function_name == "rank") diff --git a/cpp-ch/local-engine/Operator/WindowGroupLimitStep.h b/cpp-ch/local-engine/Operator/WindowGroupLimitStep.h index 55e3eaeb72d3..e7658592a682 100644 --- a/cpp-ch/local-engine/Operator/WindowGroupLimitStep.h +++ b/cpp-ch/local-engine/Operator/WindowGroupLimitStep.h @@ -31,7 +31,8 @@ class WindowGroupLimitStep : public DB::ITransformingStep const String & function_name_, const std::vector & partition_columns_, const std::vector & sort_columns_, - size_t limit_); + size_t limit_, + bool need_output_rank_values_ = false); ~WindowGroupLimitStep() override = default; String getName() const override { return "WindowGroupLimitStep"; } @@ -46,6 +47,7 @@ class WindowGroupLimitStep : public DB::ITransformingStep std::vector partition_columns; std::vector sort_columns; size_t limit; + bool 
need_output_rank_values;
 };
 }
diff --git a/cpp-ch/local-engine/Parser/RelParsers/GroupLimitRelParser.cpp b/cpp-ch/local-engine/Parser/RelParsers/GroupLimitRelParser.cpp
index 8637d987228b..1f9b8f759be8 100644
--- a/cpp-ch/local-engine/Parser/RelParsers/GroupLimitRelParser.cpp
+++ b/cpp-ch/local-engine/Parser/RelParsers/GroupLimitRelParser.cpp
@@ -16,7 +16,10 @@
  */
 #include "GroupLimitRelParser.h"
+#include
 #include
+#include
+#include
 #include
 #include
 #include
@@ -26,22 +29,44 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
+#include
+#include
+#include
+#include
 #include
+#include
 #include
+#include
+#include
+#include
+#include
 #include
 #include
 #include
 #include
 #include
+#include
+#include
+#include

 namespace DB::ErrorCodes
 {
 extern const int BAD_ARGUMENTS;
 }

+namespace DB
+{
+namespace Setting
+{
+extern const SettingsMaxThreads max_threads;
+
+}
+}
+
 namespace local_engine
 {
 GroupLimitRelParser::GroupLimitRelParser(ParserContextPtr parser_context_) : RelParser(parser_context_)
@@ -71,6 +96,33 @@ GroupLimitRelParser::parse(DB::QueryPlanPtr current_plan_, const substrait::Rel
     }
 }

+static std::vector<size_t> parsePartitionFields(const google::protobuf::RepeatedPtrField<substrait::Expression> & expressions)
+{
+    std::vector<size_t> fields;
+    for (const auto & expr : expressions)
+        if (expr.has_selection())
+            fields.push_back(static_cast<size_t>(expr.selection().direct_reference().struct_field().field()));
+        else if (expr.has_literal())
+            continue;
+        else
+            throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Unknown expression: {}", expr.DebugString());
+    return fields;
+}
+
+std::vector<size_t> parseSortFields(const google::protobuf::RepeatedPtrField<substrait::SortField> & sort_fields)
+{
+    std::vector<size_t> fields;
+    for (const auto & sort_field : sort_fields)
+        if (sort_field.expr().has_literal())
+            continue;
+        else if (sort_field.expr().has_selection())
+            fields.push_back(static_cast<size_t>(sort_field.expr().selection().direct_reference().struct_field().field()));
+        else
+            throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Unknown expression: {}", sort_field.expr().DebugString());
+    return fields;
+}
+
+
 WindowGroupLimitRelParser::WindowGroupLimitRelParser(ParserContextPtr parser_context_) : RelParser(parser_context_)
 {
 }
@@ -86,7 +138,7 @@ WindowGroupLimitRelParser::parse(DB::QueryPlanPtr current_plan_, const substrait

     current_plan = std::move(current_plan_);

-    auto partition_fields = parsePartitoinFields(win_rel_def.partition_expressions());
+    auto partition_fields = parsePartitionFields(win_rel_def.partition_expressions());
     auto sort_fields = parseSortFields(win_rel_def.sorts());
     size_t limit = static_cast<size_t>(win_rel_def.limit());

@@ -99,35 +151,34 @@ WindowGroupLimitRelParser::parse(DB::QueryPlanPtr current_plan_, const substrait
     return std::move(current_plan);
 }

-std::vector<size_t>
-WindowGroupLimitRelParser::parsePartitoinFields(const google::protobuf::RepeatedPtrField<substrait::Expression> & expressions)
-{
-    std::vector<size_t> fields;
-    for (const auto & expr : expressions)
-        if (expr.has_selection())
-            fields.push_back(static_cast<size_t>(expr.selection().direct_reference().struct_field().field()));
-        else if (expr.has_literal())
-            continue;
-        else
-            throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Unknow expression: {}", expr.DebugString());
-    return fields;
-}
-
-std::vector<size_t> WindowGroupLimitRelParser::parseSortFields(const google::protobuf::RepeatedPtrField<substrait::SortField> & sort_fields)
+AggregateGroupLimitRelParser::AggregateGroupLimitRelParser(ParserContextPtr parser_context_) : RelParser(parser_context_)
 {
-    std::vector<size_t> fields;
-    for (const auto sort_field : sort_fields)
-        if (sort_field.expr().has_literal())
-            continue;
-        else if (sort_field.expr().has_selection())
-            fields.push_back(static_cast<size_t>(sort_field.expr().selection().direct_reference().struct_field().field()));
-        else
-            throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Unknown expression: {}", sort_field.expr().DebugString());
-    return fields;
-}
-
-AggregateGroupLimitRelParser::AggregateGroupLimitRelParser(ParserContextPtr parser_context_) : RelParser(parser_context_)
+// Used to decide which branch to execute: returns 0 for the aggregate-based top-k branch
+// (low-cardinality partition keys) and 1 for the sort + window branch.
+size_t selectBranchOnPartitionKeysCardinality(
+    const std::vector<size_t> & partition_keys, double high_card_threshold, const std::list<DB::Chunk> & chunks)
 {
+    size_t total_rows = 0;
+    std::unordered_set<UInt32> ids;
+    for (const auto & chunk : chunks)
+    {
+        total_rows += chunk.getNumRows();
+        DB::WeakHash32 hash(chunk.getNumRows());
+        const auto & cols = chunk.getColumns();
+        for (auto i : partition_keys)
+            hash.update(cols[i]->getWeakHash32());
+        const auto & data = hash.getData();
+        for (size_t n = 0, sz = chunk.getNumRows(); n < sz; ++n)
+            ids.insert(data[n]);
+    }
+    LOG_DEBUG(
+        getLogger("AggregateGroupLimitRelParser"),
+        "Approximate distinct keys {}, total rows: {}, threshold: {}",
+        ids.size(),
+        total_rows,
+        high_card_threshold);
+    return ids.size() * 1.0 / (total_rows + 1) <= high_card_threshold ? 0 : 1;
 }

 DB::QueryPlanPtr AggregateGroupLimitRelParser::parse(
@@ -143,21 +194,60 @@ DB::QueryPlanPtr AggregateGroupLimitRelParser::parse(

     input_header = current_plan->getCurrentHeader();
+    win_rel_def = &rel.windowgrouplimit();
-    const auto win_rel_def = rel.windowgrouplimit();

     google::protobuf::StringValue optimize_info_str;
-    optimize_info_str.ParseFromString(win_rel_def.advanced_extension().optimization().value());
+    optimize_info_str.ParseFromString(win_rel_def->advanced_extension().optimization().value());
     auto optimization_info = WindowGroupOptimizationInfo::parse(optimize_info_str.value());
-    limit = static_cast<size_t>(win_rel_def.limit());
+    limit = static_cast<size_t>(win_rel_def->limit());
     aggregate_function_name = getAggregateFunctionName(optimization_info.window_function);
     if (limit < 1)
         throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Invalid limit: {}", limit);
-    prePrejectionForAggregateArguments();
-    addGroupLmitAggregationStep();
-    postProjectionForExplodingArrays();
-
-    LOG_DEBUG(getLogger("AggregateGroupLimitRelParser"), "Final group limit plan:\n{}", PlanUtil::explainPlan(*current_plan));
+    auto win_config = WindowConfig::loadFromContext(getContext());
+    auto high_card_threshold = win_config.aggregate_topk_high_cardinality_threshold;
+
+    auto partition_fields = parsePartitionFields(win_rel_def->partition_expressions());
+    auto branch_in_header = current_plan->getCurrentHeader();
+    auto branch_step = std::make_unique<StaticBranchStep>(
+        getContext(),
+        branch_in_header,
+        2,
+        win_config.aggregate_topk_sample_rows,
+        [partition_fields, high_card_threshold](const std::list<DB::Chunk> & chunks) -> size_t
+        { return selectBranchOnPartitionKeysCardinality(partition_fields, high_card_threshold, chunks); });
+    branch_step->setStepDescription("Window TopK");
+    steps.push_back(branch_step.get());
+    current_plan->addStep(std::move(branch_step));
+
+    // If all partition keys have low cardinality, use aggregation to get the top-k rows of each partition.
+    auto aggregation_plan = BranchStepHelper::createSubPlan(branch_in_header, 1);
+    prePrejectionForAggregateArguments(*aggregation_plan);
+    addGroupLmitAggregationStep(*aggregation_plan);
+    postProjectionForExplodingArrays(*aggregation_plan);
+    LOG_DEBUG(getLogger("AggregateGroupLimitRelParser"), "Aggregate topk plan:\n{}", PlanUtil::explainPlan(*aggregation_plan));
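+
+    // Note: both branches must end with identical output headers, since UniteBranchesStep
+    // splices them back into one pipeline by position; the window branch therefore converts
+    // its header (renaming the appended rank column) with a positional ActionsDAG below.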
+
+    // Aggregation doesn't perform well on high-cardinality keys; use sort + window to get the top-k rows instead.
+    auto window_plan = BranchStepHelper::createSubPlan(branch_in_header, 1);
+    addSortStep(*window_plan);
+    addWindowLimitStep(*window_plan);
+    auto convert_actions_dag = DB::ActionsDAG::makeConvertingActions(
+        window_plan->getCurrentHeader().getColumnsWithTypeAndName(),
+        aggregation_plan->getCurrentHeader().getColumnsWithTypeAndName(),
+        DB::ActionsDAG::MatchColumnsMode::Position);
+    auto convert_step = std::make_unique<DB::ExpressionStep>(window_plan->getCurrentHeader(), std::move(convert_actions_dag));
+    convert_step->setStepDescription("Rename rank column name");
+    window_plan->addStep(std::move(convert_step));
+    LOG_DEBUG(getLogger("AggregateGroupLimitRelParser"), "Window topk plan:\n{}", PlanUtil::explainPlan(*window_plan));
+
+    std::vector<DB::QueryPlanPtr> branch_plans;
+    branch_plans.emplace_back(std::move(aggregation_plan));
+    branch_plans.emplace_back(std::move(window_plan));
+    auto unite_branches_step = std::make_unique<UniteBranchesStep>(branch_in_header, std::move(branch_plans), 1);
+    unite_branches_step->setStepDescription("Unite TopK branches");
+    steps.push_back(unite_branches_step.get());
+
+    current_plan->addStep(std::move(unite_branches_step));
     return std::move(current_plan);
 }

@@ -165,36 +255,18 @@ String AggregateGroupLimitRelParser::getAggregateFunctionName(const String & win
 {
     if (window_function_name == "row_number")
         return "rowNumGroupArraySorted";
-#if 0
-    else if (window_function_name == "rank")
-        return "groupArrayRankSorted";
-    else if (window_function_name == "dense_rank")
-        return "groupArrayDenseRankSorted";
-#endif
     else
         throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Unsupported window function: {}", window_function_name);
 }

-static std::set<size_t> collectPartitionFields(const google::protobuf::RepeatedPtrField<substrait::Expression> & expressions)
-{
-    std::set<size_t> fields;
-    for (const auto & expr : expressions)
-        if (expr.has_selection())
-            fields.insert(static_cast<size_t>(expr.selection().direct_reference().struct_field().field()));
-        else if (expr.has_literal())
-            continue;
-        else
-            throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Unknow expression: {}", expr.DebugString());
-    return fields;
-}
-
 // Build two tuple columns as the aggregate function's arguments
-void AggregateGroupLimitRelParser::prePrejectionForAggregateArguments()
+void AggregateGroupLimitRelParser::prePrejectionForAggregateArguments(DB::QueryPlan & plan)
 {
     auto projection_actions = std::make_shared<DB::ActionsDAG>(input_header.getColumnsWithTypeAndName());

-    auto partition_fields = collectPartitionFields(win_rel_def->partition_expressions());
+    auto partition_fields = parsePartitionFields(win_rel_def->partition_expressions());
+    std::set<size_t> unique_partition_fields(partition_fields.begin(), partition_fields.end());
     DB::NameSet required_column_names;
     auto build_tuple = [&](const DB::DataTypes & data_types,
                            const Strings & names,
@@ -220,7 +292,7 @@ void AggregateGroupLimitRelParser::prePrejectionForAggregateArguments()
     for (size_t i = 0; i < input_header.columns(); ++i)
     {
         const auto & col = input_header.getByPosition(i);
-        if (partition_fields.count(i))
+        if (unique_partition_fields.count(i))
         {
             required_column_names.insert(col.name);
             aggregate_grouping_keys.push_back(col.name);
@@ -247,8 +319,7 @@ void AggregateGroupLimitRelParser::prePrejectionForAggregateArguments()

     auto expression_step = std::make_unique<DB::ExpressionStep>(input_header, std::move(*projection_actions));
     expression_step->setStepDescription("Pre-projection for aggregate group limit 
arguments"); - steps.push_back(expression_step.get()); - current_plan->addStep(std::move(expression_step)); + plan.addStep(std::move(expression_step)); } @@ -281,7 +352,7 @@ String AggregateGroupLimitRelParser::parseSortDirections(const google::protobuf: return ostr.str(); } -DB::AggregateDescription AggregateGroupLimitRelParser::buildAggregateDescription() +DB::AggregateDescription AggregateGroupLimitRelParser::buildAggregateDescription(DB::QueryPlan & plan) { DB::AggregateDescription agg_desc; agg_desc.column_name = aggregate_tuple_column_name; @@ -291,7 +362,7 @@ DB::AggregateDescription AggregateGroupLimitRelParser::buildAggregateDescription auto sort_directions = parseSortDirections(win_rel_def->sorts()); parameters.push_back(sort_directions); - auto header = current_plan->getCurrentHeader(); + auto header = plan.getCurrentHeader(); DB::DataTypes arg_types; arg_types.push_back(header.getByName(aggregate_tuple_column_name).type); @@ -300,29 +371,27 @@ DB::AggregateDescription AggregateGroupLimitRelParser::buildAggregateDescription return agg_desc; } -void AggregateGroupLimitRelParser::addGroupLmitAggregationStep() +void AggregateGroupLimitRelParser::addGroupLmitAggregationStep(DB::QueryPlan & plan) { const auto & settings = getContext()->getSettingsRef(); - DB::AggregateDescriptions agg_descs = {buildAggregateDescription()}; + DB::AggregateDescriptions agg_descs = {buildAggregateDescription(plan)}; auto params = AggregatorParamsHelper::buildParams( getContext(), aggregate_grouping_keys, agg_descs, AggregatorParamsHelper::Mode::INIT_TO_COMPLETED); - auto agg_step = std::make_unique(getContext(), current_plan->getCurrentHeader(), params, true); - steps.push_back(agg_step.get()); - current_plan->addStep(std::move(agg_step)); - LOG_DEBUG(getLogger("AggregateGroupLimitRelParser"), "Plan after add group limit:\n{}", PlanUtil::explainPlan(*current_plan)); + auto agg_step = std::make_unique(getContext(), plan.getCurrentHeader(), params, true); + plan.addStep(std::move(agg_step)); + LOG_DEBUG(getLogger("AggregateGroupLimitRelParser"), "Plan after add group limit:\n{}", PlanUtil::explainPlan(plan)); } -void AggregateGroupLimitRelParser::postProjectionForExplodingArrays() +void AggregateGroupLimitRelParser::postProjectionForExplodingArrays(DB::QueryPlan & plan) { - auto header = current_plan->getCurrentHeader(); + auto header = plan.getCurrentHeader(); /// flatten the array column. 
auto agg_result_index = header.columns() - 1; auto array_join_actions_dag = ArrayJoinHelper::applyArrayJoinOnOneColumn(header, agg_result_index); - auto new_steps = ArrayJoinHelper::addArrayJoinStep(getContext(), *current_plan, array_join_actions_dag, false); - steps.insert(steps.end(), new_steps.begin(), new_steps.end()); + auto new_steps = ArrayJoinHelper::addArrayJoinStep(getContext(), plan, array_join_actions_dag, false); - auto array_join_output_header = current_plan->getCurrentHeader(); + auto array_join_output_header = plan.getCurrentHeader(); DB::ActionsDAG flatten_actions_dag(array_join_output_header.getColumnsWithTypeAndName()); DB::Names flatten_output_column_names; for (size_t i = 0; i < array_join_output_header.columns() - 1; ++i) @@ -347,12 +416,11 @@ void AggregateGroupLimitRelParser::postProjectionForExplodingArrays() } flatten_actions_dag.removeUnusedActions(flatten_output_column_names); LOG_DEBUG(getLogger("AggregateGroupLimitRelParser"), "Actions dag for untupling aggregate result:\n{}", flatten_actions_dag.dumpDAG()); - auto flatten_expression_step = std::make_unique(current_plan->getCurrentHeader(), std::move(flatten_actions_dag)); + auto flatten_expression_step = std::make_unique(plan.getCurrentHeader(), std::move(flatten_actions_dag)); flatten_expression_step->setStepDescription("Untuple the aggregation result"); - steps.push_back(flatten_expression_step.get()); - current_plan->addStep(std::move(flatten_expression_step)); + plan.addStep(std::move(flatten_expression_step)); - auto flatten_tuple_output_header = current_plan->getCurrentHeader(); + auto flatten_tuple_output_header = plan.getCurrentHeader(); auto window_result_column = flatten_tuple_output_header.getByPosition(flatten_tuple_output_header.columns() - 1); /// The result column is put at the end of the header. 
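+    /// Rebuild the input header and append the window result column at the end, so downstream
+    /// operators see the same column order as the vanilla window-group-limit output.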
auto output_header = input_header;
@@ -364,8 +432,46 @@
     LOG_DEBUG(getLogger("AggregateGroupLimitRelParser"), "Actions dag for replacing columns:\n{}", adjust_pos_actions_dag.dumpDAG());
     auto adjust_pos_expression_step = std::make_unique<DB::ExpressionStep>(flatten_tuple_output_header, std::move(adjust_pos_actions_dag));
     adjust_pos_expression_step->setStepDescription("Adjust position of the output columns");
-    steps.push_back(adjust_pos_expression_step.get());
-    current_plan->addStep(std::move(adjust_pos_expression_step));
+    plan.addStep(std::move(adjust_pos_expression_step));
+}
+
+void AggregateGroupLimitRelParser::addSortStep(DB::QueryPlan & plan)
+{
+    auto header = plan.getCurrentHeader();
+    DB::SortDescription full_sort_descr;
+    auto partition_fields = parsePartitionFields(win_rel_def->partition_expressions());
+    for (auto field : partition_fields)
+    {
+        const auto & col = header.getByPosition(field);
+        full_sort_descr.emplace_back(col.name, 1, -1);
+    }
+    auto sort_descr = SortRelParser::parseSortDescription(win_rel_def->sorts(), header);
+    full_sort_descr.insert(full_sort_descr.end(), sort_descr.begin(), sort_descr.end());
+
+    DB::SortingStep::Settings settings(*getContext());
+    auto config = MemoryConfig::loadFromContext(getContext());
+    double spill_mem_ratio = config.spill_mem_ratio;
+    settings.worth_external_sort = [spill_mem_ratio]() -> bool { return currentThreadGroupMemoryUsageRatio() > spill_mem_ratio; };
+    auto sorting_step = std::make_unique<DB::SortingStep>(plan.getCurrentHeader(), full_sort_descr, 0, settings);
+    sorting_step->setStepDescription("Sorting step");
+    plan.addStep(std::move(sorting_step));
+}
+
+void AggregateGroupLimitRelParser::addWindowLimitStep(DB::QueryPlan & plan)
+{
+    google::protobuf::StringValue optimize_info_str;
+    optimize_info_str.ParseFromString(win_rel_def->advanced_extension().optimization().value());
+    auto optimization_info = WindowGroupOptimizationInfo::parse(optimize_info_str.value());
+    auto window_function_name = optimization_info.window_function;
+
+    auto partition_fields = parsePartitionFields(win_rel_def->partition_expressions());
+    auto sort_fields = parseSortFields(win_rel_def->sorts());
+    size_t limit = static_cast<size_t>(win_rel_def->limit());
+
+    auto window_group_limit_step
+        = std::make_unique<WindowGroupLimitStep>(plan.getCurrentHeader(), window_function_name, partition_fields, sort_fields, limit, true);
+    window_group_limit_step->setStepDescription("Window group limit");
+    plan.addStep(std::move(window_group_limit_step));
 }

 void registerWindowGroupLimitRelParser(RelParserFactory & factory)
diff --git a/cpp-ch/local-engine/Parser/RelParsers/GroupLimitRelParser.h b/cpp-ch/local-engine/Parser/RelParsers/GroupLimitRelParser.h
index b8c71c819b91..b9f3aa6631c3 100644
--- a/cpp-ch/local-engine/Parser/RelParsers/GroupLimitRelParser.h
+++ b/cpp-ch/local-engine/Parser/RelParsers/GroupLimitRelParser.h
@@ -52,9 +52,6 @@ class WindowGroupLimitRelParser : public RelParser
 private:
     DB::QueryPlanPtr current_plan;
     String window_function_name;
-
-    std::vector<size_t> parsePartitoinFields(const google::protobuf::RepeatedPtrField<substrait::Expression> & expressions);
-    std::vector<size_t> parseSortFields(const google::protobuf::RepeatedPtrField<substrait::SortField> & sort_fields);
 };

 class AggregateGroupLimitRelParser : public RelParser
@@ -78,12 +75,14 @@ class AggregateGroupLimitRelParser : public RelParser

     String getAggregateFunctionName(const String & window_function_name);

-    void prePrejectionForAggregateArguments();
+    void prePrejectionForAggregateArguments(DB::QueryPlan & plan);

-    void 
addGroupLmitAggregationStep(); + void addGroupLmitAggregationStep(DB::QueryPlan & plan); String parseSortDirections(const google::protobuf::RepeatedPtrField & sort_fields); - DB::AggregateDescription buildAggregateDescription(); + DB::AggregateDescription buildAggregateDescription(DB::QueryPlan & plan); + void postProjectionForExplodingArrays(DB::QueryPlan & plan); - void postProjectionForExplodingArrays(); + void addSortStep(DB::QueryPlan & plan); + void addWindowLimitStep(DB::QueryPlan & plan); }; }
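Taken together, the feature is gated by the Spark conf exercised in the test above (`spark.gluten.sql.columnar.backend.ch.runtime_config.enable_window_group_limit_to_aggregate`) and tuned by the two new `WindowConfig` knobs, `window.aggregate_topk_sample_rows` and `window.aggregate_topk_high_cardinality_threshold`. A back-of-the-envelope sketch of the branch decision with the default values (illustrative arithmetic only; the real selector hashes the partition columns with `WeakHash32` instead of counting keys directly):

```cpp
#include <cstddef>
#include <iostream>

int main()
{
    // Defaults from WindowConfig: sample up to 5000 rows, threshold 0.6.
    const size_t sample_rows = 5000;
    const double threshold = 0.6;

    // Suppose sampling observes 1200 distinct partition keys in 5000 rows.
    const size_t distinct_keys = 1200;
    const double ratio = static_cast<double>(distinct_keys) / (sample_rows + 1); // ~0.24

    // ratio <= threshold -> branch 0 (aggregate top-k); otherwise branch 1 (sort + window).
    std::cout << "branch " << (ratio <= threshold ? 0 : 1) << '\n'; // prints: branch 0
}
```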