window group limit has bad performance
lgbo-ustc committed Nov 26, 2024
1 parent ab37862 commit 73f7b6d
Showing 9 changed files with 229 additions and 179 deletions.
@@ -24,12 +24,9 @@ import org.apache.gluten.expression.WindowFunctionsBuilder
import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions._
// import org.apache.spark.sql.catalyst.expressions.aggregate._
// import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution._
// import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.SparkPlan
// import org.apache.spark.sql.execution.window.Final
import org.apache.spark.sql.types._

// When a window function is used to find the first rows of each partition, we can convert it to an aggregate
@@ -103,11 +100,9 @@ case class ConverRowNumbertWindowToAggregateRule(spark: SparkSession)

def isSupportedWindowFunction(windowExpressions: Seq[NamedExpression]): Boolean = {
if (windowExpressions.length != 1) {
logDebug(s"xxx windowExpressions length: ${windowExpressions.length}")
return false
}
val windowFunction = extractWindowFunction(windowExpressions(0))
logDebug(s"xxx windowFunction: $windowFunction")
windowFunction match {
case _: RowNumber => true
case _ => false
@@ -3168,9 +3168,7 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends GlutenClickHouseTPCHAbstr
(runtimeConfigPrefix + "window.aggregate_topk_high_cardinality_threshold", "2.0")) {
def checkWindowGroupLimit(df: DataFrame): Unit = {
val expands = collectWithSubqueries(df.queryExecution.executedPlan) {
case e: ExpandExecTransformer
if (e.child.isInstanceOf[CHAggregateGroupLimitExecTransformer]) =>
e
case e: CHAggregateGroupLimitExecTransformer => e
}
assert(expands.size == 1)
}
@@ -3243,9 +3241,7 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends GlutenClickHouseTPCHAbstr
(runtimeConfigPrefix + "window.aggregate_topk_high_cardinality_threshold", "0.0")) {
def checkWindowGroupLimit(df: DataFrame): Unit = {
val expands = collectWithSubqueries(df.queryExecution.executedPlan) {
case e: ExpandExecTransformer
if (e.child.isInstanceOf[CHAggregateGroupLimitExecTransformer]) =>
e
case e: CHAggregateGroupLimitExecTransformer => e
}
assert(expands.size == 1)
}
105 changes: 105 additions & 0 deletions cpp-ch/local-engine/Common/SortUtils.cpp
@@ -0,0 +1,105 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "SortUtils.h"
#include <IO/Operators.h>
#include <IO/WriteBufferFromString.h>
#include <Poco/Logger.h>
#include <Common/Exception.h>
#include <Common/logger_useful.h>

namespace DB::ErrorCodes
{
extern const int BAD_ARGUMENTS;
extern const int LOGICAL_ERROR;
}

namespace local_engine
{
DB::SortDescription parseSortFields(const DB::Block & header, const google::protobuf::RepeatedPtrField<substrait::Expression> & expressions)
{
DB::SortDescription description;
for (const auto & expr : expressions)
if (expr.has_selection())
{
auto pos = expr.selection().direct_reference().struct_field().field();
const auto & col_name = header.getByPosition(pos).name;
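// Default: ascending with nulls last (the same {1, 1} pair as in direction_map below).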
description.push_back(DB::SortColumnDescription(col_name, 1, 1));
}
else if (expr.has_literal())
continue;
else
throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Unknow expression as sort field: {}", expr.DebugString());
return description;
}

DB::SortDescription parseSortFields(const DB::Block & header, const google::protobuf::RepeatedPtrField<substrait::SortField> & sort_fields)
{
static std::map<int, std::pair<int, int>> direction_map = {{1, {1, -1}}, {2, {1, 1}}, {3, {-1, 1}}, {4, {-1, -1}}};

DB::SortDescription sort_descr;
for (int i = 0, sz = sort_fields.size(); i < sz; ++i)
{
const auto & sort_field = sort_fields[i];
/// There is no point in sorting by a constant column.
if (sort_field.expr().has_literal())
continue;

if (!sort_field.expr().has_selection() || !sort_field.expr().selection().has_direct_reference()
|| !sort_field.expr().selection().direct_reference().has_struct_field())
{
throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Unsupport sort field");
}
auto field_pos = sort_field.expr().selection().direct_reference().struct_field().field();

auto direction_iter = direction_map.find(sort_field.direction());
if (direction_iter == direction_map.end())
throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Unsuppor sort direction: {}", sort_field.direction());
const auto & col_name = header.getByPosition(field_pos).name;
sort_descr.emplace_back(col_name, direction_iter->second.first, direction_iter->second.second);
}
return sort_descr;
}

std::string
buildSQLLikeSortDescription(const DB::Block & header, const google::protobuf::RepeatedPtrField<substrait::SortField> & sort_fields)
{
static const std::unordered_map<int, std::string> order_directions
= {{1, " asc nulls first"}, {2, " asc nulls last"}, {3, " desc nulls first"}, {4, " desc nulls last"}};
size_t n = 0;
DB::WriteBufferFromOwnString ostr;
for (const auto & sort_field : sort_fields)
{
auto it = order_directions.find(sort_field.direction());
if (it == order_directions.end())
throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Unknow sort direction: {}", sort_field.direction());
if (!sort_field.expr().has_selection())
{
throw DB::Exception(
DB::ErrorCodes::BAD_ARGUMENTS, "Sort field must be a column reference. but got {}", sort_field.DebugString());
}
auto ref = sort_field.expr().selection().direct_reference().struct_field().field();
const auto & col_name = header.getByPosition(ref).name;
if (n)
ostr << String(",");
// The col_name may contain '#', which may cause CH to fail to parse it, hence the backticks.
ostr << "`" << col_name << "`" << it->second;
n += 1;
}
LOG_DEBUG(getLogger("AggregateGroupLimitRelParser"), "Order by clause: {}", ostr.str());
return ostr.str();
}
}
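
For context, a minimal standalone sketch (not part of this commit) of the direction mapping used by parseSortFields above. It assumes the codes 1–4 follow substrait's SortDirection numbering (asc/desc combined with nulls first/last), and that the pair is the (direction, nulls_direction) consumed by DB::SortColumnDescription:

#include <iostream>
#include <map>
#include <utility>

int main()
{
    // Mirrors direction_map in parseSortFields: code -> (direction, nulls_direction).
    // 1 = asc nulls first, 2 = asc nulls last, 3 = desc nulls first, 4 = desc nulls last.
    static const std::map<int, std::pair<int, int>> direction_map
        = {{1, {1, -1}}, {2, {1, 1}}, {3, {-1, 1}}, {4, {-1, -1}}};

    for (const auto & [code, dirs] : direction_map)
        std::cout << "code " << code << " -> direction=" << dirs.first
                  << ", nulls_direction=" << dirs.second << '\n';
    return 0;
}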
33 changes: 33 additions & 0 deletions cpp-ch/local-engine/Common/SortUtils.h
@@ -0,0 +1,33 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once
#include <Core/Block.h>
#include <Core/SortDescription.h>
#include <google/protobuf/repeated_field.h>
#include <substrait/plan.pb.h>

namespace local_engine
{
// Convert substrait expressions or sort fields into a ClickHouse sort description.
DB::SortDescription
parseSortFields(const DB::Block & header, const google::protobuf::RepeatedPtrField<substrait::Expression> & expressions);
DB::SortDescription parseSortFields(const DB::Block & header, const google::protobuf::RepeatedPtrField<substrait::SortField> & sort_fields);

std::string
buildSQLLikeSortDescription(const DB::Block & header, const google::protobuf::RepeatedPtrField<substrait::SortField> & sort_fields);
}
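
As an illustration (not from the commit), the clause built by buildSQLLikeSortDescription is a comma-separated list of backtick-quoted column names, each followed by an explicit nulls placement. A hypothetical standalone sketch of that formatting, with made-up column names:

#include <iostream>
#include <string>
#include <utility>
#include <vector>

int main()
{
    // Hypothetical sort fields: (column name, direction suffix). Column names may
    // contain '#', hence the backtick quoting used by buildSQLLikeSortDescription.
    const std::vector<std::pair<std::string, std::string>> fields
        = {{"a#1", " asc nulls first"}, {"b#2", " desc nulls last"}};

    std::string clause;
    for (const auto & [name, suffix] : fields)
    {
        if (!clause.empty())
            clause += ",";
        clause += "`" + name + "`" + suffix;
    }
    std::cout << clause << '\n'; // prints: `a#1` asc nulls first,`b#2` desc nulls last
    return 0;
}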
77 changes: 8 additions & 69 deletions cpp-ch/local-engine/Operator/BranchStep.cpp
@@ -35,68 +35,6 @@
namespace local_engine
{

class BranchOutputTransform : public DB::IProcessor
{
public:
using Status = DB::IProcessor::Status;
BranchOutputTransform(const DB::Block & header_) : DB::IProcessor({header_}, {header_}) { }
~BranchOutputTransform() override = default;

String getName() const override { return "BranchOutputTransform"; }

Status prepare() override;
void work() override;

private:
bool has_output = false;
DB::Chunk output_chunk;
bool has_input = false;
DB::Chunk input_chunk;
};

BranchOutputTransform::Status BranchOutputTransform::prepare()
{
auto & output = outputs.front();
auto & input = inputs.front();
if (output.isFinished())
{
input.close();
return Status::Finished;
}
if (has_output)
{
if (output.canPush())
{
output.push(std::move(output_chunk));
has_output = false;
}
return Status::PortFull;
}
if (has_input)
return Status::Ready;
if (input.isFinished())
{
output.finish();
return Status::Finished;
}
input.setNeeded();
if (!input.hasData())
return Status::NeedData;
input_chunk = input.pull(true);
has_input = true;
return Status::Ready;
}

void BranchOutputTransform::work()
{
if (has_input)
{
output_chunk = std::move(input_chunk);
has_output = true;
has_input = false;
}
}

class BranchHookSource : public DB::IProcessor
{
public:
@@ -240,13 +178,6 @@ void StaticBranchStep::transformPipeline(DB::QueryPipelineBuilder & pipeline, co
auto branch_transform = std::make_shared<StaticBranchTransform>(header, max_sample_rows, branches, selector);
DB::connect(*output, branch_transform->getInputs().front());
new_processors.push_back(branch_transform);

for (auto & branch_output : branch_transform->getOutputs())
{
auto branch_processor = std::make_shared<BranchOutputTransform>(header);
DB::connect(branch_output, branch_processor->getInputs().front());
new_processors.push_back(branch_processor);
}
}
return new_processors;
};
@@ -278,6 +209,14 @@ void UniteBranchesStep::transformPipeline(DB::QueryPipelineBuilder & pipeline, c
{
DB::Processors new_processors;
size_t branch_index = 0;
if (child_outputs.size() != branch_plans.size())
{
throw DB::Exception(
DB::ErrorCodes::LOGICAL_ERROR,
"Output port's size({}) is not equal to branches size({})",
child_outputs.size(),
branch_plans.size());
}
for (auto output : child_outputs)
{
auto & branch_plan = branch_plans[branch_index];
2 changes: 2 additions & 0 deletions cpp-ch/local-engine/Operator/BranchStep.h
@@ -62,6 +62,8 @@ class StaticBranchStep : public DB::ITransformingStep
BranchSelector selector;
};


// It would be better to build the execution branches on QueryPlan.
class UniteBranchesStep : public DB::ITransformingStep
{
public: