Skip to content

Commit

Permalink
window group limit has bad performance
Browse files Browse the repository at this point in the history
  • Loading branch information
lgbo-ustc committed Nov 26, 2024
1 parent ab37862 commit 630e144
Show file tree
Hide file tree
Showing 8 changed files with 217 additions and 105 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,9 @@ import org.apache.gluten.expression.WindowFunctionsBuilder
import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions._
// import org.apache.spark.sql.catalyst.expressions.aggregate._
// import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.SparkPlan
// import org.apache.spark.sql.execution.window.Final
import org.apache.spark.sql.types._

// When a window function is only used to find the first rows of each partition, we can convert it to an aggregate
Expand Down Expand Up @@ -103,11 +100,9 @@ case class ConverRowNumbertWindowToAggregateRule(spark: SparkSession)

def isSupportedWindowFunction(windowExpressions: Seq[NamedExpression]): Boolean = {
if (windowExpressions.length != 1) {
logDebug(s"xxx windowExpressions length: ${windowExpressions.length}")
return false
}
val windowFunction = extractWindowFunction(windowExpressions(0))
logDebug(s"xxx windowFunction: $windowFunction")
windowFunction match {
case _: RowNumber => true
case _ => false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3168,9 +3168,7 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends GlutenClickHouseTPCHAbstr
(runtimeConfigPrefix + "window.aggregate_topk_high_cardinality_threshold", "2.0")) {
def checkWindowGroupLimit(df: DataFrame): Unit = {
val expands = collectWithSubqueries(df.queryExecution.executedPlan) {
case e: ExpandExecTransformer
if (e.child.isInstanceOf[CHAggregateGroupLimitExecTransformer]) =>
e
case e: CHAggregateGroupLimitExecTransformer => e
}
assert(expands.size == 1)
}
Expand Down Expand Up @@ -3243,9 +3241,7 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends GlutenClickHouseTPCHAbstr
(runtimeConfigPrefix + "window.aggregate_topk_high_cardinality_threshold", "0.0")) {
def checkWindowGroupLimit(df: DataFrame): Unit = {
val expands = collectWithSubqueries(df.queryExecution.executedPlan) {
case e: ExpandExecTransformer
if (e.child.isInstanceOf[CHAggregateGroupLimitExecTransformer]) =>
e
case e: CHAggregateGroupLimitExecTransformer => e
}
assert(expands.size == 1)
}
Expand Down
105 changes: 105 additions & 0 deletions cpp-ch/local-engine/Common/SortUtils.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "SortUtils.h"
#include <IO/Operators.h>
#include <IO/WriteBufferFromString.h>
#include <Poco/Logger.h>
#include <Common/Exception.h>
#include <Common/logger_useful.h>

namespace DB::ErrorCodes
{
extern const int BAD_ARGUMENTS;
extern const int LOGICAL_ERROR;
}

namespace local_engine
{
/// Builds a sort description from plain substrait expressions.
///
/// Each field-reference expression contributes one sort column, resolved by position in `header`,
/// with direction 1 and nulls direction 1. Literal expressions are skipped.
///
/// Throws BAD_ARGUMENTS for an expression that is neither a literal nor a direct struct-field
/// reference. `header.getByPosition` is expected to reject an out-of-range position.
DB::SortDescription parseSortFields(const DB::Block & header, const google::protobuf::RepeatedPtrField<substrait::Expression> & expressions)
{
    DB::SortDescription description;
    for (const auto & expr : expressions)
    {
        if (expr.has_literal())
            continue;

        // Validate the whole reference path, the same way the SortField overload does. Accessors on
        // an unset protobuf sub-message return default values, so without this check a selection
        // that is not a direct struct-field reference would be read as field position 0 and the
        // result would silently sort by the wrong column.
        const bool is_direct_field_ref = expr.has_selection() && expr.selection().has_direct_reference()
            && expr.selection().direct_reference().has_struct_field();
        if (!is_direct_field_ref)
            throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Unknow expression as sort field: {}", expr.DebugString());

        const auto pos = expr.selection().direct_reference().struct_field().field();
        const auto & col_name = header.getByPosition(pos).name;
        description.emplace_back(col_name, 1, 1);
    }
    return description;
}

/// Builds a sort description from substrait sort fields.
///
/// Each sort field must be a direct struct-field reference; its position is resolved against
/// `header` to obtain the column name. Literal sort fields are skipped.
///
/// Throws LOGICAL_ERROR when a sort field is not a direct struct-field reference or when its
/// direction code is not one of 1..4.
DB::SortDescription parseSortFields(const DB::Block & header, const google::protobuf::RepeatedPtrField<substrait::SortField> & sort_fields)
{
    // Sort direction code -> the two integers passed to the sort column description after the
    // column name. buildSQLLikeSortDescription renders the same codes as: 1 = asc nulls first,
    // 2 = asc nulls last, 3 = desc nulls first, 4 = desc nulls last.
    // The table is read-only, so keep it const.
    static const std::map<int, std::pair<int, int>> direction_map = {{1, {1, -1}}, {2, {1, 1}}, {3, {-1, 1}}, {4, {-1, -1}}};

    DB::SortDescription sort_descr;
    for (const auto & sort_field : sort_fields)
    {
        const auto & expr = sort_field.expr();

        /// There is no meaning to sort a const column.
        if (expr.has_literal())
            continue;

        const bool is_direct_field_ref
            = expr.has_selection() && expr.selection().has_direct_reference() && expr.selection().direct_reference().has_struct_field();
        if (!is_direct_field_ref)
            throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Unsupport sort field");

        const auto field_pos = expr.selection().direct_reference().struct_field().field();

        const auto direction_iter = direction_map.find(sort_field.direction());
        if (direction_iter == direction_map.end())
            throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Unsuppor sort direction: {}", sort_field.direction());

        const auto & col_name = header.getByPosition(field_pos).name;
        sort_descr.emplace_back(col_name, direction_iter->second.first, direction_iter->second.second);
    }
    return sort_descr;
}

/// Renders substrait sort fields as a SQL-like ORDER BY list, for example
/// "`a` asc nulls first,`b` desc nulls last".
///
/// Column names are resolved by field position in `header` and wrapped in back quotes.
///
/// Throws BAD_ARGUMENTS when a direction code is not one of 1..4, or when a sort field is not a
/// direct struct-field reference.
std::string
buildSQLLikeSortDescription(const DB::Block & header, const google::protobuf::RepeatedPtrField<substrait::SortField> & sort_fields)
{
    static const std::unordered_map<int, std::string> order_directions
        = {{1, " asc nulls first"}, {2, " asc nulls last"}, {3, " desc nulls first"}, {4, " desc nulls last"}};

    bool is_first = true;
    DB::WriteBufferFromOwnString ostr;
    for (const auto & sort_field : sort_fields)
    {
        const auto it = order_directions.find(sort_field.direction());
        if (it == order_directions.end())
            throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Unknow sort direction: {}", sort_field.direction());

        // Check the whole reference path, as parseSortFields does for sort fields. Checking only
        // has_selection() would let a non-direct reference fall through and be read as protobuf
        // defaults, i.e. field position 0.
        const auto & expr = sort_field.expr();
        const bool is_direct_field_ref
            = expr.has_selection() && expr.selection().has_direct_reference() && expr.selection().direct_reference().has_struct_field();
        if (!is_direct_field_ref)
        {
            throw DB::Exception(
                DB::ErrorCodes::BAD_ARGUMENTS, "Sort field must be a column reference. but got {}", sort_field.DebugString());
        }

        const auto ref = expr.selection().direct_reference().struct_field().field();
        const auto & col_name = header.getByPosition(ref).name;

        if (!is_first)
            ostr << ",";
        is_first = false;

        // The column name may contain '#', which can make ClickHouse fail to parse it, so it is
        // wrapped in back quotes.
        // NOTE(review): a name that itself contains a back quote would break this quoting - confirm
        // that such names cannot reach this function.
        ostr << "`" << col_name << "`" << it->second;
    }

    // Materialize the buffer once and reuse it for both the log line and the return value.
    std::string order_by = ostr.str();
    LOG_DEBUG(getLogger("AggregateGroupLimitRelParser"), "Order by clasue: {}", order_by);
    return order_by;
}
}
33 changes: 33 additions & 0 deletions cpp-ch/local-engine/Common/SortUtils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once
#include <Core/Block.h>
#include <Core/SortDescription.h>
#include <google/protobuf/repeated_field.h>
#include <substrait/plan.pb.h>

namespace local_engine
{
// Convert plain expressions into a sort description.
// Each selection expression is resolved by field position in `header` and added as a sort column
// with direction 1 and nulls direction 1. Literal expressions are skipped; any other expression
// kind raises a BAD_ARGUMENTS exception.
DB::SortDescription
parseSortFields(const DB::Block & header, const google::protobuf::RepeatedPtrField<substrait::Expression> & expressions);
// Convert sort fields into a sort description.
// Each sort field must be a direct struct-field reference into `header`; its direction code (1..4)
// selects the direction and nulls direction of the resulting sort column. Literal sort fields are
// skipped. An unsupported sort field or direction raises a LOGICAL_ERROR exception.
DB::SortDescription parseSortFields(const DB::Block & header, const google::protobuf::RepeatedPtrField<substrait::SortField> & sort_fields);

// Render sort fields as a comma-separated, SQL-like ORDER BY list. Every entry is the back-quoted
// column name from `header` followed by " asc nulls first", " asc nulls last", " desc nulls first"
// or " desc nulls last". An unknown direction or a sort field that is not a column reference
// raises a BAD_ARGUMENTS exception.
std::string
buildSQLLikeSortDescription(const DB::Block & header, const google::protobuf::RepeatedPtrField<substrait::SortField> & sort_fields);
}
8 changes: 8 additions & 0 deletions cpp-ch/local-engine/Operator/BranchStep.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,14 @@ void UniteBranchesStep::transformPipeline(DB::QueryPipelineBuilder & pipeline, c
{
DB::Processors new_processors;
size_t branch_index = 0;
if (child_outputs.size() != branch_plans.size())
{
throw DB::Exception(
DB::ErrorCodes::LOGICAL_ERROR,
"Output port's size({}) is not equal to branches size({})",
child_outputs.size(),
branch_plans.size());
}
for (auto output : child_outputs)
{
auto & branch_plan = branch_plans[branch_index];
Expand Down
55 changes: 5 additions & 50 deletions cpp-ch/local-engine/Operator/WindowGroupLimitStep.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,36 +40,18 @@ enum class WindowGroupLimitFunction
DenseRank
};


static DB::Block buildOutputHeader(const DB::Block & input_header, bool need_output_rank_values)
{
if (!need_output_rank_values)
return input_header;
DB::Block output_header = input_header;
auto type = std::make_shared<DB::DataTypeInt32>();
auto col = type->createColumn();
output_header.insert(DB::ColumnWithTypeAndName(std::move(col), type, "rank_value"));
return output_header;
}

template <WindowGroupLimitFunction function>
class WindowGroupLimitTransform : public DB::IProcessor
{
public:
using Status = DB::IProcessor::Status;
explicit WindowGroupLimitTransform(
const DB::Block & header_,
const std::vector<size_t> & partition_columns_,
const std::vector<size_t> & sort_columns_,
size_t limit_,
bool need_output_rank_values_ = false)
: DB::IProcessor({header_}, {buildOutputHeader(header_, need_output_rank_values_)})
const DB::Block & header_, const std::vector<size_t> & partition_columns_, const std::vector<size_t> & sort_columns_, size_t limit_)
: DB::IProcessor({header_}, {header_})
, header(header_)
, partition_columns(partition_columns_)
, sort_columns(sort_columns_)
, limit(limit_)
, need_output_rank_values(need_output_rank_values_)

{
}
~WindowGroupLimitTransform() override = default;
Expand Down Expand Up @@ -136,11 +118,6 @@ class WindowGroupLimitTransform : public DB::IProcessor
if (!output_columns.empty() && output_columns[0]->size() > 0)
{
auto rows = output_columns[0]->size();
if (rank_value_column)
{
output_columns.push_back(std::move(rank_value_column));
rank_value_column.reset();
}
output_chunk = DB::Chunk(std::move(output_columns), rows);
output_columns.clear();
has_output = true;
Expand All @@ -156,13 +133,11 @@ class WindowGroupLimitTransform : public DB::IProcessor
std::vector<size_t> sort_columns;
// Limitations for each partition.
size_t limit = 0;
bool need_output_rank_values;

bool has_input = false;
DB::Chunk input_chunk;
bool has_output = false;
DB::MutableColumns output_columns;
DB::MutableColumnPtr rank_value_column = nullptr;
DB::Chunk output_chunk;

// We don't have window frame here. in fact all of frame are (unbounded preceding, current row]
Expand All @@ -175,13 +150,6 @@ class WindowGroupLimitTransform : public DB::IProcessor
DB::Columns partition_start_row_columns;
DB::Columns peer_group_start_row_columns;


void tryCreateRankValueColumn()
{
if (!rank_value_column)
rank_value_column = DB::DataTypeInt32().createColumn();
}

size_t advanceNextPartition(const DB::Chunk & chunk, size_t start_offset)
{
if (partition_start_row_columns.empty())
Expand Down Expand Up @@ -265,12 +233,6 @@ class WindowGroupLimitTransform : public DB::IProcessor
size_t rows = end_offset - start_offset;
size_t limit_remained = limit - current_row_rank_value + 1;
rows = rows > limit_remained ? limit_remained : rows;
if (need_output_rank_values)
{
tryCreateRankValueColumn();
for (Int32 i = 0; i < static_cast<Int32>(rows); ++i)
typeid_cast<DB::ColumnVector<Int32> *>(rank_value_column.get())->insertValue(current_row_rank_value + i);
}
insertResultValue(chunk, start_offset, rows);

current_row_rank_value += rows;
Expand All @@ -282,11 +244,6 @@ class WindowGroupLimitTransform : public DB::IProcessor
{
auto next_peer_group_start_offset = advanceNextPeerGroup(chunk, peer_group_start_offset, end_offset);
size_t group_rows = next_peer_group_start_offset - peer_group_start_offset;
if (need_output_rank_values)
{
tryCreateRankValueColumn();
rank_value_column->insertMany(current_row_rank_value, group_rows);
}
insertResultValue(chunk, peer_group_start_offset, group_rows);
try_end_peer_group(peer_group_start_offset, next_peer_group_start_offset, end_offset, chunk_rows);
peer_group_start_offset = next_peer_group_start_offset;
Expand Down Expand Up @@ -335,14 +292,12 @@ WindowGroupLimitStep::WindowGroupLimitStep(
const String & function_name_,
const std::vector<size_t> & partition_columns_,
const std::vector<size_t> & sort_columns_,
size_t limit_,
bool need_output_rank_values_)
: DB::ITransformingStep(input_header_, buildOutputHeader(input_header_, need_output_rank_values_), getTraits())
size_t limit_)
: DB::ITransformingStep(input_header_, input_header_, getTraits())
, function_name(function_name_)
, partition_columns(partition_columns_)
, sort_columns(sort_columns_)
, limit(limit_)
, need_output_rank_values(need_output_rank_values_)
{
}

Expand All @@ -366,7 +321,7 @@ void WindowGroupLimitStep::transformPipeline(DB::QueryPipelineBuilder & pipeline
[&](const DB::Block & header)
{
return std::make_shared<WindowGroupLimitTransform<WindowGroupLimitFunction::RowNumber>>(
header, partition_columns, sort_columns, limit, need_output_rank_values);
header, partition_columns, sort_columns, limit);
});
}
else if (function_name == "rank")
Expand Down
4 changes: 1 addition & 3 deletions cpp-ch/local-engine/Operator/WindowGroupLimitStep.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,7 @@ class WindowGroupLimitStep : public DB::ITransformingStep
const String & function_name_,
const std::vector<size_t> & partition_columns_,
const std::vector<size_t> & sort_columns_,
size_t limit_,
bool need_output_rank_values_ = false);
size_t limit_);
~WindowGroupLimitStep() override = default;

String getName() const override { return "WindowGroupLimitStep"; }
Expand All @@ -47,7 +46,6 @@ class WindowGroupLimitStep : public DB::ITransformingStep
std::vector<size_t> partition_columns;
std::vector<size_t> sort_columns;
size_t limit;
bool need_output_rank_values;
};

}
Loading

0 comments on commit 630e144

Please sign in to comment.