use window on high cardinality keys
lgbo-ustc committed Nov 25, 2024
1 parent 2a80d3c commit 0bd7d4f
Showing 15 changed files with 974 additions and 190 deletions.
@@ -369,6 +369,7 @@ object CHBackendSettings extends BackendSettingsApi with Logging {
     )
   }
 
+  // If the partition keys are high cardinality, the aggregation method is slower.
   def enableConvertWindowGroupLimitToAggregate(): Boolean = {
     SparkEnv.get.conf.getBoolean(
       CHConf.runtimeConfig("enable_window_group_limit_to_aggregate"),
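Note on the flag above: `CHConf.runtimeConfig` keys surface on the Spark side under the `spark.gluten.sql.columnar.backend.ch.runtime_config` prefix, as the full key in the test below confirms. A minimal usage sketch, assuming a SparkSession on the Gluten ClickHouse backend and an existing table `t(a string, b int)`:

```scala
// Sketch: enable the window-group-limit-to-aggregate rewrite, then run the
// top-k-per-partition pattern it targets (row_number plus a filter on the rank).
spark.conf.set(
  "spark.gluten.sql.columnar.backend.ch.runtime_config.enable_window_group_limit_to_aggregate",
  "true")
spark.sql("""
  |select a, b from (
  |  select a, b, row_number() over (partition by a order by b) as r from t
  |) where r <= 3
  |""".stripMargin).show()
```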
@@ -3162,62 +3162,66 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends GlutenClickHouseTPCHAbstr
   }
 
   test("GLUTEN-7905 get topk of window by aggregate") {
-    def checkWindowGroupLimit(df: DataFrame): Unit = {
-      val expands = collectWithSubqueries(df.queryExecution.executedPlan) {
-        case e: ExpandExecTransformer
-            if (e.child.isInstanceOf[CHAggregateGroupLimitExecTransformer]) =>
-          e
-      }
-      assert(expands.size == 1)
-    }
-    spark.sql("create table test_win_top (a string, b int, c int) using parquet")
-    spark.sql("""
-                |insert into test_win_top values
-                |('a', 3, 3), ('a', 1, 5), ('a', 2, 2), ('a', null, null), ('a', null, 1),
-                |('b', 1, 1), ('b', 2, 1),
-                |('c', 2, 3)
-                |""".stripMargin)
-    compareResultsAgainstVanillaSpark(
-      """
-        |select a, b, c, row_number() over (partition by a order by b desc nulls first) as r
-        |from test_win_top
-        |""".stripMargin,
-      true,
-      checkWindowGroupLimit
-    )
-    compareResultsAgainstVanillaSpark(
-      """
-        |select a, b, c, row_number() over (partition by a order by b desc, c nulls last) as r
-        |from test_win_top
-        |""".stripMargin,
-      true,
-      checkWindowGroupLimit
-    )
-    compareResultsAgainstVanillaSpark(
-      """
-        |select a, b, c, row_number() over (partition by a order by b asc nulls first) as r
-        |from test_win_top
-        |""".stripMargin,
-      true,
-      checkWindowGroupLimit
-    )
-    compareResultsAgainstVanillaSpark(
-      """
-        |select a, b, c, row_number() over (partition by a order by b asc nulls last) as r
-        |from test_win_top
-        |""".stripMargin,
-      true,
-      checkWindowGroupLimit
-    )
-    compareResultsAgainstVanillaSpark(
-      """
-        |select a, b, c, row_number() over (partition by a order by b , c) as r
-        |from test_win_top
-        |""".stripMargin,
-      true,
-      checkWindowGroupLimit
-    )
-    spark.sql("drop table if exists test_win_top")
+    withSQLConf((
+      "spark.gluten.sql.columnar.backend.ch.runtime_config.enable_window_group_limit_to_aggregate",
+      "true")) {
+      def checkWindowGroupLimit(df: DataFrame): Unit = {
+        val expands = collectWithSubqueries(df.queryExecution.executedPlan) {
+          case e: ExpandExecTransformer
+              if (e.child.isInstanceOf[CHAggregateGroupLimitExecTransformer]) =>
+            e
+        }
+        assert(expands.size == 1)
+      }
+      spark.sql("create table test_win_top (a string, b int, c int) using parquet")
+      spark.sql("""
+                  |insert into test_win_top values
+                  |('a', 3, 3), ('a', 1, 5), ('a', 2, 2), ('a', null, null), ('a', null, 1),
+                  |('b', 1, 1), ('b', 2, 1),
+                  |('c', 2, 3)
+                  |""".stripMargin)
+      compareResultsAgainstVanillaSpark(
+        """
+          |select a, b, c, row_number() over (partition by a order by b desc nulls first) as r
+          |from test_win_top
+          |""".stripMargin,
+        true,
+        checkWindowGroupLimit
+      )
+      compareResultsAgainstVanillaSpark(
+        """
+          |select a, b, c, row_number() over (partition by a order by b desc, c nulls last) as r
+          |from test_win_top
+          |""".stripMargin,
+        true,
+        checkWindowGroupLimit
+      )
+      compareResultsAgainstVanillaSpark(
+        """
+          |select a, b, c, row_number() over (partition by a order by b asc nulls first, c) as r
+          |from test_win_top
+          |""".stripMargin,
+        true,
+        checkWindowGroupLimit
+      )
+      compareResultsAgainstVanillaSpark(
+        """
+          |select a, b, c, row_number() over (partition by a order by b asc nulls last) as r
+          |from test_win_top
+          |""".stripMargin,
+        true,
+        checkWindowGroupLimit
+      )
+      compareResultsAgainstVanillaSpark(
+        """
+          |select a, b, c, row_number() over (partition by a order by b , c) as r
+          |from test_win_top
+          |""".stripMargin,
+        true,
+        checkWindowGroupLimit
+      )
+      spark.sql("drop table if exists test_win_top")
+    }
 
   }

29 changes: 6 additions & 23 deletions cpp-ch/local-engine/AggregateFunctions/GroupLimitFunctions.cpp
@@ -41,6 +41,7 @@
 
 #include <Poco/Logger.h>
 #include <Common/logger_useful.h>
+#include "base/defines.h"
 
 namespace DB::ErrorCodes
 {
@@ -72,7 +73,6 @@ struct RowNumGroupArraySortedData
         const auto & pos = sort_order.pos;
         const auto & asc = sort_order.direction;
         const auto & nulls_first = sort_order.nulls_direction;
-        LOG_ERROR(getLogger("GroupLimitFunction"), "xxx pos: {} tuple size: {} {}", pos, rhs.size(), lhs.size());
         bool l_is_null = lhs[pos].isNull();
         bool r_is_null = rhs[pos].isNull();
         if (l_is_null && r_is_null)
@@ -120,25 +120,17 @@ struct RowNumGroupArraySortedData
         values[current_index] = current;
     }
 
-    ALWAYS_INLINE void addElement(const Data & data, const SortOrderFields & sort_orders, size_t max_elements)
+    ALWAYS_INLINE void addElement(const Data && data, const SortOrderFields & sort_orders, size_t max_elements)
     {
         if (values.size() >= max_elements)
         {
-            LOG_ERROR(
-                getLogger("GroupLimitFunction"),
-                "xxxx values size: {}, limit: {}, tuple size: {} {}",
-                values.size(),
-                max_elements,
-                data.size(),
-                values[0].size());
             if (!compare(data, values[0], sort_orders))
                 return;
             values[0] = data;
             heapReplaceTop(sort_orders);
             return;
         }
-        values.push_back(data);
-        LOG_ERROR(getLogger("GroupLimitFunction"), "add new element: {} {}", values.size(), values.back().size());
+        values.emplace_back(std::move(data));
         auto cmp = [&sort_orders](const Data & a, const Data & b) { return compare(a, b, sort_orders); };
         std::push_heap(values.begin(), values.end(), cmp);
     }
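The branch above is the standard bounded top-k heap: `values` is a max-heap of at most `max_elements` rows under the sort order, so `values[0]` is the worst row kept, and a new row either replaces it (`heapReplaceTop`) or is dropped. A Scala sketch of the same idea, with a plain `Ordering` standing in for the `SortOrderFields` comparison (illustrative only, assumes `limit > 0`):

```scala
import scala.collection.mutable

// Keep the `limit` best (smallest) rows seen so far, mirroring addElement:
// fill the heap first, then replace the top only when a better row arrives.
def topK[T](rows: Iterator[T], limit: Int)(implicit ord: Ordering[T]): Seq[T] = {
  val heap = mutable.PriorityQueue.empty[T](ord) // max-heap: head is the worst kept row
  rows.foreach { row =>
    if (heap.size < limit) heap.enqueue(row)
    else if (ord.lt(row, heap.head)) { // strictly better than the worst kept row
      heap.dequeue()
      heap.enqueue(row)
    }
  }
  heap.dequeueAll.reverse // best to worst
}

// topK(Iterator(5, 1, 4, 2, 3), 3) == Seq(1, 2, 3)
```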
@@ -190,7 +182,7 @@ class RowNumGroupArraySorted final : public DB::IAggregateFunctionDataHelper<Row
 public:
     explicit RowNumGroupArraySorted(DB::DataTypePtr data_type, const DB::Array & parameters_)
         : DB::IAggregateFunctionDataHelper<RowNumGroupArraySortedData, RowNumGroupArraySorted>(
-            {data_type}, parameters_, getRowNumReultDataType(data_type))
+              {data_type}, parameters_, getRowNumReultDataType(data_type))
     {
         if (parameters_.size() != 2)
             throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "{} needs two parameters: limit and order clause", getName());
@@ -212,23 +204,14 @@ class RowNumGroupArraySorted final : public DB::IAggregateFunctionDataHelper<Row
     {
         auto & data = this->data(place);
         DB::Tuple data_tuple = (*columns[0])[row_num].safeGet<DB::Tuple>();
-        // const DB::Tuple & data_tuple = *(static_cast<const DB::Tuple *>(&((*columns[0])[row_num])));
-        LOG_ERROR(
-            getLogger("GroupLimitFunction"),
-            "xxx col len: {}, row num: {}, tuple size: {}, type: {}",
-            columns[0]->size(),
-            row_num,
-            data_tuple.size(),
-            (*columns[0])[row_num].getType());
-        ;
-        this->data(place).addElement(data_tuple, sort_order_fields, limit);
+        this->data(place).addElement(std::move(data_tuple), sort_order_fields, limit);
     }
 
     void merge(DB::AggregateDataPtr __restrict place, DB::ConstAggregateDataPtr rhs, DB::Arena * /*arena*/) const override
     {
         auto & rhs_values = this->data(rhs).values;
         for (auto & rhs_element : rhs_values)
-            this->data(place).addElement(rhs_element, sort_order_fields, limit);
+            this->data(place).addElement(std::move(rhs_element), sort_order_fields, limit);
     }
 
     void serialize(DB::ConstAggregateDataPtr __restrict place, DB::WriteBuffer & buf, std::optional<size_t> /* version */) const override
1 change: 0 additions & 1 deletion cpp-ch/local-engine/Common/AggregateUtil.cpp
@@ -47,7 +47,6 @@ extern const SettingsBool enable_memory_bound_merging_of_aggregation_results;
 extern const SettingsUInt64 aggregation_in_order_max_block_bytes;
 extern const SettingsUInt64 group_by_two_level_threshold;
 extern const SettingsFloat min_hit_rate_to_use_consecutive_keys_optimization;
-extern const SettingsMaxThreads max_threads;
 extern const SettingsUInt64 max_block_size;
 }

15 changes: 15 additions & 0 deletions cpp-ch/local-engine/Common/ArrayJoinHelper.cpp
@@ -150,6 +150,21 @@ addArrayJoinStep(DB::ContextPtr context, DB::QueryPlan & plan, const DB::Actions
         steps.emplace_back(array_join_step.get());
         plan.addStep(std::move(array_join_step));
         // LOG_DEBUG(logger, "plan2:{}", PlanUtil::explainPlan(*query_plan));
+
+        /// Post-projection after array join(Optional)
+        if (!ignore_actions_dag(splitted_actions_dags.after_array_join))
+        {
+            auto step_after_array_join
+                = std::make_unique<DB::ExpressionStep>(plan.getCurrentHeader(), std::move(splitted_actions_dags.after_array_join));
+            step_after_array_join->setStepDescription("Post-projection In Generate");
+            steps.emplace_back(step_after_array_join.get());
+            plan.addStep(std::move(step_after_array_join));
+            // LOG_DEBUG(logger, "plan3:{}", PlanUtil::explainPlan(*query_plan));
+        }
     }
+    else
+    {
+        throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Expect array join node in actions_dag");
+    }
 
     return steps;
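The new `else` branch also turns a malformed DAG into an explicit error instead of silently producing an incomplete step list. For the optional post-projection itself, an illustrative query shape (hypothetical example, not taken from this commit) is one that applies a further expression on top of the exploded column; that trailing expression is what `after_array_join` carries and what the added `ExpressionStep` evaluates:

```scala
// `col + 1` cannot run inside the array join step itself; it is evaluated
// by the post-projection added after the ArrayJoinStep.
spark.sql("select col + 1 from (select explode(sequence(1, 3)) as col) t").show()
```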
11 changes: 10 additions & 1 deletion cpp-ch/local-engine/Common/GlutenConfig.cpp
@@ -141,4 +141,13 @@ MergeTreeCacheConfig MergeTreeCacheConfig::loadFromContext(const DB::ContextPtr
     config.enable_data_prefetch = context->getConfigRef().getBool(ENABLE_DATA_PREFETCH, config.enable_data_prefetch);
     return config;
 }
-}
+
+WindowConfig WindowConfig::loadFromContext(const DB::ContextPtr & context)
+{
+    WindowConfig config;
+    config.aggregate_topk_sample_rows = context->getConfigRef().getUInt64(WINDOW_AGGREGATE_TOPK_SAMPLE_ROWS, 5000);
+    config.aggregate_topk_high_cardinality_threshold
+        = context->getConfigRef().getDouble(WINDOW_AGGREGATE_TOPK_HIGH_CARDINALITY_THRESHOLD, 0.6);
+    return config;
+}
+}
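Reading the two knobs together with the commit title, the intended decision appears to be: sample a bounded prefix of rows, estimate the distinct-key ratio of the window partition keys, and keep the plain window implementation when that ratio looks high. A Scala sketch of such a heuristic, under those assumptions (the C++ call site that consumes `WindowConfig` is not shown in this excerpt):

```scala
// Assumed decision rule: distinct-key ratio over a bounded sample of rows.
def useWindowPath(
    partitionKeys: Iterator[String],
    sampleRows: Int = 5000, // window.aggregate_topk_sample_rows
    threshold: Double = 0.6 // window.aggregate_topk_high_cardinality_threshold
): Boolean = {
  val sample = partitionKeys.take(sampleRows).toSeq
  sample.nonEmpty && sample.distinct.size.toDouble / sample.size > threshold
}
```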
22 changes: 18 additions & 4 deletions cpp-ch/local-engine/Common/GlutenConfig.h
@@ -56,9 +56,12 @@ struct GraceMergingAggregateConfig
 {
     inline static const String MAX_GRACE_AGGREGATE_MERGING_BUCKETS = "max_grace_aggregate_merging_buckets";
     inline static const String THROW_ON_OVERFLOW_GRACE_AGGREGATE_MERGING_BUCKETS = "throw_on_overflow_grace_aggregate_merging_buckets";
-    inline static const String AGGREGATED_KEYS_BEFORE_EXTEND_GRACE_AGGREGATE_MERGING_BUCKETS = "aggregated_keys_before_extend_grace_aggregate_merging_buckets";
-    inline static const String MAX_PENDING_FLUSH_BLOCKS_PER_GRACE_AGGREGATE_MERGING_BUCKET = "max_pending_flush_blocks_per_grace_aggregate_merging_bucket";
-    inline static const String MAX_ALLOWED_MEMORY_USAGE_RATIO_FOR_AGGREGATE_MERGING = "max_allowed_memory_usage_ratio_for_aggregate_merging";
+    inline static const String AGGREGATED_KEYS_BEFORE_EXTEND_GRACE_AGGREGATE_MERGING_BUCKETS
+        = "aggregated_keys_before_extend_grace_aggregate_merging_buckets";
+    inline static const String MAX_PENDING_FLUSH_BLOCKS_PER_GRACE_AGGREGATE_MERGING_BUCKET
+        = "max_pending_flush_blocks_per_grace_aggregate_merging_bucket";
+    inline static const String MAX_ALLOWED_MEMORY_USAGE_RATIO_FOR_AGGREGATE_MERGING
+        = "max_allowed_memory_usage_ratio_for_aggregate_merging";
 
     size_t max_grace_aggregate_merging_buckets = 32;
     bool throw_on_overflow_grace_aggregate_merging_buckets = false;
Expand All @@ -73,7 +76,8 @@ struct StreamingAggregateConfig
{
inline static const String AGGREGATED_KEYS_BEFORE_STREAMING_AGGREGATING_EVICT = "aggregated_keys_before_streaming_aggregating_evict";
inline static const String MAX_MEMORY_USAGE_RATIO_FOR_STREAMING_AGGREGATING = "max_memory_usage_ratio_for_streaming_aggregating";
inline static const String HIGH_CARDINALITY_THRESHOLD_FOR_STREAMING_AGGREGATING = "high_cardinality_threshold_for_streaming_aggregating";
inline static const String HIGH_CARDINALITY_THRESHOLD_FOR_STREAMING_AGGREGATING
= "high_cardinality_threshold_for_streaming_aggregating";
inline static const String ENABLE_STREAMING_AGGREGATING = "enable_streaming_aggregating";

size_t aggregated_keys_before_streaming_aggregating_evict = 1024;
@@ -154,6 +158,16 @@ struct MergeTreeCacheConfig
     static MergeTreeCacheConfig loadFromContext(const DB::ContextPtr & context);
 };
 
+struct WindowConfig
+{
+public:
+    inline static const String WINDOW_AGGREGATE_TOPK_SAMPLE_ROWS = "window.aggregate_topk_sample_rows";
+    inline static const String WINDOW_AGGREGATE_TOPK_HIGH_CARDINALITY_THRESHOLD = "window.aggregate_topk_high_cardinality_threshold";
+    size_t aggregate_topk_sample_rows = 5000;
+    double aggregate_topk_high_cardinality_threshold = 0.6;
+    static WindowConfig loadFromContext(const DB::ContextPtr & context);
+};
+
 namespace PathConfig
 {
 inline constexpr const char * USE_CURRENT_DIRECTORY_AS_TMP = "use_current_directory_as_tmp";
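If the new keys follow the same runtime_config mapping as the boolean flag exercised in the test above (an assumption; the dotted `window.*` names may resolve differently in the backend's configuration tree), they could be tuned from the Spark side roughly like this:

```scala
// Hypothetical tuning of the window top-k knobs via the CH runtime config.
spark.conf.set(
  "spark.gluten.sql.columnar.backend.ch.runtime_config.window.aggregate_topk_sample_rows",
  "10000")
spark.conf.set(
  "spark.gluten.sql.columnar.backend.ch.runtime_config.window.aggregate_topk_high_cardinality_threshold",
  "0.8")
```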