Skip to content

Commit

Permalink
[Enhancement] adjust agg pushdown strategy for broadcast (StarRocks#5…
Browse files Browse the repository at this point in the history
…4572)

Signed-off-by: stephen <[email protected]>
  • Loading branch information
stephen-shelby authored and magzhu committed Jan 6, 2025
1 parent e0a8787 commit 24b23d4
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 26 deletions.
1 change: 1 addition & 0 deletions be/src/exec/aggregator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1372,6 +1372,7 @@ void Aggregator::_init_agg_hash_variant(HashVariantType& hash_variant) {
}
}
}

VLOG_ROW << "hash type is "
<< static_cast<typename std::underlying_type<typename HashVariantType::Type>::type>(type);
hash_variant.init(_state, type, _agg_stat);
Expand Down
13 changes: 13 additions & 0 deletions fe/fe-core/src/main/java/com/starrocks/qe/SessionVariable.java
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,8 @@ public class SessionVariable implements Serializable, Writable, Cloneable {
public static final String CBO_PRUNE_SHUFFLE_COLUMN_RATE = "cbo_prune_shuffle_column_rate";
public static final String CBO_PUSH_DOWN_AGGREGATE_MODE = "cbo_push_down_aggregate_mode";
public static final String CBO_PUSH_DOWN_AGGREGATE_ON_BROADCAST_JOIN = "cbo_push_down_aggregate_on_broadcast_join";
public static final String CBO_PUSH_DOWN_AGGREGATE_ON_BROADCAST_JOIN_ROW_COUNT_LIMIT =
"cbo_push_down_aggregate_on_broadcast_join_row_count_limit";

public static final String CBO_PUSH_DOWN_DISTINCT_BELOW_WINDOW = "cbo_push_down_distinct_below_window";
public static final String CBO_PUSH_DOWN_AGGREGATE = "cbo_push_down_aggregate";
Expand Down Expand Up @@ -1546,6 +1548,9 @@ public static MaterializedViewRewriteMode parse(String str) {
@VarAttr(name = CBO_PUSH_DOWN_AGGREGATE_ON_BROADCAST_JOIN, flag = VariableMgr.INVISIBLE)
private boolean cboPushDownAggregateOnBroadcastJoin = true;

@VarAttr(name = CBO_PUSH_DOWN_AGGREGATE_ON_BROADCAST_JOIN_ROW_COUNT_LIMIT, flag = VariableMgr.INVISIBLE)
private long cboPushDownAggregateOnBroadcastJoinRowCountLimit = 250000;

// auto, global, local
@VarAttr(name = CBO_PUSH_DOWN_AGGREGATE, flag = VariableMgr.INVISIBLE)
private String cboPushDownAggregate = "global";
Expand Down Expand Up @@ -3621,6 +3626,14 @@ public void setCboPushDownAggregateOnBroadcastJoin(boolean cboPushDownAggregateO
this.cboPushDownAggregateOnBroadcastJoin = cboPushDownAggregateOnBroadcastJoin;
}

public long getCboPushDownAggregateOnBroadcastJoinRowCountLimit() {
return cboPushDownAggregateOnBroadcastJoinRowCountLimit;
}

public void setCboPushDownAggregateOnBroadcastJoinRowCountLimit(long cboPushDownAggregateOnBroadcastJoinRowCountLimit) {
this.cboPushDownAggregateOnBroadcastJoinRowCountLimit = cboPushDownAggregateOnBroadcastJoinRowCountLimit;
}

public String getCboPushDownAggregate() {
return cboPushDownAggregate;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@
import java.util.Set;
import java.util.stream.Collectors;

import static com.starrocks.sql.optimizer.statistics.StatisticsEstimateCoefficient.SMALL_BROADCAST_JOIN_MAX_NDV_LIMIT;

/*
* Collect all can be push down aggregate context, to get which aggregation can be
* pushed down and the push down path.
Expand Down Expand Up @@ -473,13 +475,23 @@ private boolean checkStatistics(AggregatePushDownContext context, ColumnRefSet g

List<ColumnStatistic>[] cards = new List[] {lower, medium, high};

groupBys.getStream().map(factory::getColumnRef)
Set<ColumnStatistic> columnStatistics = groupBys.getStream()
.map(factory::getColumnRef)
.map(s -> ExpressionStatisticCalculator.calculate(s, statistics))
.forEach(s -> cards[groupByCardinality(s, statistics.getOutputRowCount())].add(s));
.collect(Collectors.toSet());
columnStatistics.forEach(s -> cards[groupByCardinality(s, statistics.getOutputRowCount())].add(s));

double lowerCartesian = lower.stream().map(ColumnStatistic::getDistinctValuesCount).reduce((a, b) -> a * b)
.orElse(Double.MAX_VALUE);

// target is the immediate child of a small broadcast join
// and the ndv of all columns is less than SMALL_BROADCAST_JOIN_MAX_NDV_LIMIT
if (pushDownMode == PUSH_DOWN_AGG_AUTO && context.immediateChildOfSmallBroadcastJoin) {
if (columnStatistics.stream().anyMatch(x -> x.getDistinctValuesCount() > SMALL_BROADCAST_JOIN_MAX_NDV_LIMIT)) {
return false;
}
}

// pow(row_count/20, a half of lower column size)
double lowerUpper = Math.max(statistics.getOutputRowCount() / 20, 1);
lowerUpper = Math.pow(lowerUpper, Math.max(lower.size() / 2, 1));
Expand Down Expand Up @@ -516,15 +528,9 @@ private boolean checkStatistics(AggregatePushDownContext context, ColumnRefSet g
}
}

// 2. forbidden rules
// 2.1 target is the immediate child of a small broadcast join and the cardinality of the aggregation is not lower.
if (pushDownMode == PUSH_DOWN_AGG_AUTO && context.immediateChildOfSmallBroadcastJoin) {
return false;
}

// 2.2 high cardinality >= 2
// 2.3 medium cardinality > 2
// 2.4 high cardinality = 1 and medium cardinality > 0
// 2.1 high cardinality >= 2
// 2.2 medium cardinality > 2
// 2.3 high cardinality = 1 and medium cardinality > 0
if (high.size() >= 2 || medium.size() > 2 || (high.size() == 1 && !medium.isEmpty())) {
return false;
}
Expand Down Expand Up @@ -553,9 +559,9 @@ private boolean checkStatistics(AggregatePushDownContext context, ColumnRefSet g
return false;
}

// high(2): cardinality/count > MEDIUM_AGGREGATE
// medium(1): cardinality/count <= MEDIUM_AGGREGATE and > LOW_AGGREGATE
// lower(0): cardinality/count < LOW_AGGREGATE
// high(2): row_count / cardinality < MEDIUM_AGGREGATE_EFFECT_COEFFICIENT
// medium(1): row_count / cardinality >= MEDIUM_AGGREGATE_EFFECT_COEFFICIENT and < LOW_AGGREGATE_EFFECT_COEFFICIENT
// lower(0): row_count / cardinality >= LOW_AGGREGATE_EFFECT_COEFFICIENT
private int groupByCardinality(ColumnStatistic statistic, double rowCount) {
if (statistic.isUnknown()) {
return 2;
Expand Down Expand Up @@ -586,7 +592,7 @@ private boolean isSmallBroadcastJoin(OptExpression optExpression) {
}
double rightRows = rightStatistics.getOutputRowCount();
return rightRows <= sessionVariable.getBroadcastRowCountLimit() &&
rightRows <= StatisticsEstimateCoefficient.SMALL_BROADCAST_JOIN_ROW_COUNT_UPPER_BOUND;
rightRows <= sessionVariable.getCboPushDownAggregateOnBroadcastJoinRowCountLimit();
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ public class StatisticsEstimateCoefficient {
public static final double LOWER_AGGREGATE_EFFECT_COEFFICIENT = 10000;
public static final double LOW_AGGREGATE_EFFECT_COEFFICIENT = 1000;
public static final double MEDIUM_AGGREGATE_EFFECT_COEFFICIENT = 100;
public static final int SMALL_BROADCAST_JOIN_ROW_COUNT_UPPER_BOUND = 4096;
public static final int SMALL_BROADCAST_JOIN_MAX_NDV_LIMIT = 100000;

public static final double EXTREME_HIGH_AGGREGATE_EFFECT_COEFFICIENT = 3;
// default selectivity for anti join
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,36 +148,36 @@ private static Stream<Arguments> testPushDownProvider() {
Arguments[] cases = new Arguments[] {
Arguments.of("Q01", 4, 4, false, 6, true, 4, false, 6, true),
Arguments.of("Q02", 2, 6, true, 6, true, 6, true, 6, true),
Arguments.of("Q03", 2, 4, true, 4, true, 4, true, 4, true),
Arguments.of("Q03", 2, 2, false, 4, true, 4, true, 4, true),
// Although the number of aggregators is the same, the aggregator was pushed down.
// This is caused by the CTE. orig: CTE inline, auto~high: CTE
Arguments.of("Q04", 12, 12, true, 12, true, 12, true, 12, true),
Arguments.of("Q05", 8, 16, true, 16, true, 16, true, 16, true),
Arguments.of("Q08", 4, 6, true, 6, true, 6, true, 6, true),
Arguments.of("Q11", 8, 8, true, 8, true, 8, true, 8, true),
Arguments.of("Q12", 2, 4, true, 4, true, 4, true, 4, true),
Arguments.of("Q12", 2, 2, false, 4, true, 4, true, 4, true),
Arguments.of("Q15", 2, 2, false, 4, true, 4, true, 4, true),
Arguments.of("Q19", 2, 2, false, 4, true, 2, false, 2, false),
Arguments.of("Q20", 2, 4, true, 4, true, 4, true, 4, true),
Arguments.of("Q20", 2, 2, false, 4, true, 4, true, 4, true),
Arguments.of("Q23_1", 10, 13, true, 13, true, 13, true, 13, true),
Arguments.of("Q24_1", 6, 6, false, 7, true, 6, false, 6, false),
Arguments.of("Q24_2", 6, 6, false, 7, true, 6, false, 6, false),
Arguments.of("Q30", 4, 4, false, 6, true, 4, false, 4, false),
Arguments.of("Q31", 4, 8, true, 8, true, 8, true, 8, true),
Arguments.of("Q33", 8, 8, false, 14, true, 14, true, 14, true),
Arguments.of("Q37", 2, 4, true, 8, true, 6, true, 7, true),
Arguments.of("Q37", 2, 2, false, 8, true, 6, true, 7, true),
Arguments.of("Q38", 8, 14, true, 20, true, 14, true, 17, true),
Arguments.of("Q41", 4, 4, false, 6, true, 4, false, 4, false),
Arguments.of("Q42", 2, 4, true, 4, true, 4, true, 4, true),
Arguments.of("Q43", 2, 4, true, 4, true, 4, true, 4, true),
Arguments.of("Q45", 6, 6, false, 8, true, 6, false, 8, true),
Arguments.of("Q46", 2, 2, false, 4, true, 2, false, 2, false),
Arguments.of("Q47", 2, 2, true, 4, true, 4, true, 4, true),
Arguments.of("Q51", 4, 8, true, 8, true, 8, true, 8, true),
Arguments.of("Q52", 2, 4, true, 4, true, 4, true, 4, true),
Arguments.of("Q51", 4, 4, false, 8, true, 8, true, 8, true),
Arguments.of("Q52", 2, 2, false, 4, true, 4, true, 4, true),
Arguments.of("Q53", 2, 2, false, 4, true, 4, true, 4, true),
Arguments.of("Q54", 9, 11, true, 18, true, 11, true, 17, true),
Arguments.of("Q55", 2, 4, true, 4, true, 4, true, 4, true),
Arguments.of("Q54", 9, 9, false, 18, true, 11, true, 17, true),
Arguments.of("Q55", 2, 2, false, 4, true, 4, true, 4, true),
Arguments.of("Q56", 8, 8, false, 14, true, 14, true, 14, true),
Arguments.of("Q57", 2, 2, true, 4, true, 4, true, 4, true),
Arguments.of("Q58", 6, 12, true, 12, true, 12, true, 12, true),
Expand All @@ -194,13 +194,13 @@ private static Stream<Arguments> testPushDownProvider() {
Arguments.of("Q78", 6, 6, false, 9, true, 6, false, 6, false),
Arguments.of("Q79", 2, 2, false, 4, true, 2, false, 2, false),
Arguments.of("Q81", 4, 4, false, 6, true, 4, false, 4, false),
Arguments.of("Q82", 2, 4, true, 8, true, 6, true, 7, true),
Arguments.of("Q82", 2, 2, false, 8, true, 6, true, 7, true),
Arguments.of("Q83", 6, 12, true, 12, true, 12, true, 12, true),
Arguments.of("Q87", 8, 14, true, 20, true, 14, true, 17, true),
Arguments.of("Q89", 2, 2, false, 4, true, 4, true, 4, true),
Arguments.of("Q91", 2, 4, true, 4, true, 4, true, 4, true),
Arguments.of("Q97", 6, 6, false, 12, true, 10, true, 12, true),
Arguments.of("Q98", 2, 4, true, 4, true, 4, true, 4, true),
Arguments.of("Q98", 2, 2, false, 4, true, 4, true, 4, true),
};

return Arrays.stream(cases);
Expand Down

0 comments on commit 24b23d4

Please sign in to comment.