[GLUTEN-7087][CH] Support WindowGroupLimitExec #7176

Merged · 3 commits · Sep 12, 2024
@@ -418,4 +418,6 @@ object CHBackendSettings extends BackendSettingsApi with Logging {
}
}
}

override def supportWindowGroupLimitExec(rankLikeFunction: Expression): Boolean = true
}
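
For context, the query shape that Spark 3.5's InferWindowGroupLimit rule compiles into WindowGroupLimitExec is a top-k filter over a rank-like window function, which this setting now lets the CH backend claim. A minimal sketch (table and column names assumed, not part of this diff):

// Illustrative only: row_number/rank/dense_rank bounded by a filter.
val topK =
  """
    |select * from (
    |  select n_regionkey, n_name,
    |    row_number() over (partition by n_regionkey order by n_name) as rn
    |  from nation
    |) where rn <= 5
    |""".stripMargin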
@@ -53,6 +53,7 @@ import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleEx
import org.apache.spark.sql.execution.joins.{BuildSideRelation, ClickHouseBuildSideRelation, HashedRelationBroadcastMode}
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.execution.utils.{CHExecUtil, PushDownUtil}
import org.apache.spark.sql.execution.window._
import org.apache.spark.sql.types.{DecimalType, StructType}
import org.apache.spark.sql.vectorized.ColumnarBatch

@@ -909,4 +910,19 @@ class CHSparkPlanExecApi extends SparkPlanExecApi with Logging {
toScale: Int): DecimalType = {
SparkShimLoader.getSparkShims.genDecimalRoundExpressionOutput(decimalType, toScale)
}

override def genWindowGroupLimitTransformer(
partitionSpec: Seq[Expression],
orderSpec: Seq[SortOrder],
rankLikeFunction: Expression,
limit: Int,
mode: WindowGroupLimitMode,
child: SparkPlan): SparkPlan =
CHWindowGroupLimitExecTransformer(
partitionSpec,
orderSpec,
rankLikeFunction,
limit,
mode,
child)
}
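
The factory keeps operator replacement backend-agnostic: common Gluten code talks only to the SparkPlanExecApi interface, and the CH backend hands back its own transformer. A hypothetical dispatch sketch (the real rewrite rule lives in Gluten's shared planner code; shown here only to illustrate the call):

// Sketch only; `plan` is some physical plan node being rewritten.
plan match {
  case wgl: WindowGroupLimitExec =>
    BackendsApiManager.getSparkPlanExecApiInstance.genWindowGroupLimitTransformer(
      wgl.partitionSpec, wgl.orderSpec, wgl.rankLikeFunction, wgl.limit, wgl.mode, wgl.child)
  case other => other
}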
@@ -0,0 +1,187 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.gluten.execution

import org.apache.gluten.backendsapi.BackendsApiManager
import org.apache.gluten.exception.GlutenNotSupportException
import org.apache.gluten.expression._
import org.apache.gluten.expression.{ConverterUtils, ExpressionConverter}
import org.apache.gluten.extension.ValidationResult
import org.apache.gluten.metrics.MetricsUpdater
import org.apache.gluten.substrait.`type`.TypeBuilder
import org.apache.gluten.substrait.SubstraitContext
import org.apache.gluten.substrait.extensions.ExtensionBuilder
import org.apache.gluten.substrait.rel.{RelBuilder, RelNode}

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.window.{Final, Partial, WindowGroupLimitMode}

import com.google.protobuf.StringValue
import io.substrait.proto.SortField

import scala.collection.JavaConverters._

case class CHWindowGroupLimitExecTransformer(
partitionSpec: Seq[Expression],
orderSpec: Seq[SortOrder],
rankLikeFunction: Expression,
limit: Int,
mode: WindowGroupLimitMode,
child: SparkPlan)
extends UnaryTransformSupport {

@transient override lazy val metrics =
BackendsApiManager.getMetricsApiInstance.genWindowTransformerMetrics(sparkContext)

override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
copy(child = newChild)

override def metricsUpdater(): MetricsUpdater =
BackendsApiManager.getMetricsApiInstance.genWindowTransformerMetricsUpdater(metrics)

override def output: Seq[Attribute] = child.output

override def requiredChildDistribution: Seq[Distribution] = mode match {
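// Partial runs before the exchange and only prunes each task's local data, so any
// input distribution works; Final must see every row of a window partition together,
// hence the clustered (or all-tuples) requirement.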
case Partial => super.requiredChildDistribution
case Final =>
if (partitionSpec.isEmpty) {
AllTuples :: Nil
} else {
ClusteredDistribution(partitionSpec) :: Nil
}
}

override def requiredChildOrdering: Seq[Seq[SortOrder]] = {
if (BackendsApiManager.getSettings.requiredChildOrderingForWindowGroupLimit()) {
Seq(partitionSpec.map(SortOrder(_, Ascending)) ++ orderSpec)
} else {
Seq(Nil)
}
}

override def outputOrdering: Seq[SortOrder] = {
if (requiredChildOrdering.forall(_.isEmpty)) {
// The Velox backend `TopNRowNumber` does not require child ordering, because it
// uses a hash table to store partitions and a priority queue to track the top
// `limit` rows. Ideally its output is unordered but grouped by the partition keys.
// To be safe, we do not propagate the ordering here.
// TODO: Make the framework aware of grouped data distribution
Nil
} else {
child.outputOrdering
}
}

override def outputPartitioning: Partitioning = child.outputPartitioning

def getWindowGroupLimitRel(
context: SubstraitContext,
originalInputAttributes: Seq[Attribute],
operatorId: Long,
input: RelNode,
validation: Boolean): RelNode = {
val args = context.registeredFunction
// Partition By Expressions
val partitionsExpressions = partitionSpec
.map(
ExpressionConverter
.replaceWithExpressionTransformer(_, attributeSeq = child.output)
.doTransform(args))
.asJava

// Sort By Expressions
val sortFieldList =
orderSpec.map {
order =>
val builder = SortField.newBuilder()
val exprNode = ExpressionConverter
.replaceWithExpressionTransformer(order.child, attributeSeq = child.output)
.doTransform(args)
builder.setExpr(exprNode.toProtobuf)
builder.setDirectionValue(SortExecTransformer.transformSortDirection(order))
builder.build()
}.asJava
if (!validation) {
val windowFunction = rankLikeFunction match {
case _: RowNumber => ExpressionNames.ROW_NUMBER
case _: Rank => ExpressionNames.RANK
case _: DenseRank => ExpressionNames.DENSE_RANK
case _ => throw new GlutenNotSupportException(s"Unknown window function $rankLikeFunction")
}
val parametersStr = new StringBuffer("WindowGroupLimitParameters:")
parametersStr
.append("window_function=")
.append(windowFunction)
.append("\n")
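// Packed as e.g. "WindowGroupLimitParameters:window_function=row_number\n";
// the native side presumably parses this to pick the rank-like function.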
val message = StringValue.newBuilder().setValue(parametersStr.toString).build()
val extensionNode = ExtensionBuilder.makeAdvancedExtension(
BackendsApiManager.getTransformerApiInstance.packPBMessage(message),
null)
RelBuilder.makeWindowGroupLimitRel(
input,
partitionsExpressions,
sortFieldList,
limit,
extensionNode,
context,
operatorId)
} else {
// Use an extension node to send the input types through the Substrait plan for validation.
val inputTypeNodeList = originalInputAttributes
.map(attr => ConverterUtils.getTypeNode(attr.dataType, attr.nullable))
.asJava
val extensionNode = ExtensionBuilder.makeAdvancedExtension(
BackendsApiManager.getTransformerApiInstance.packPBMessage(
TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf))

RelBuilder.makeWindowGroupLimitRel(
input,
partitionsExpressions,
sortFieldList,
limit,
extensionNode,
context,
operatorId)
}
}

override protected def doValidateInternal(): ValidationResult = {
if (!BackendsApiManager.getSettings.supportWindowGroupLimitExec(rankLikeFunction)) {
return ValidationResult
.failed(s"Found unsupported rank like function: $rankLikeFunction")
}
val substraitContext = new SubstraitContext
val operatorId = substraitContext.nextOperatorId(this.nodeName)

val relNode =
getWindowGroupLimitRel(substraitContext, child.output, operatorId, null, validation = true)

doNativeValidation(substraitContext, relNode)
}

override protected def doTransform(context: SubstraitContext): TransformContext = {
val childCtx = child.asInstanceOf[TransformSupport].transform(context)
val operatorId = context.nextOperatorId(this.nodeName)

val currRel =
getWindowGroupLimitRel(context, child.output, operatorId, childCtx.root, validation = false)
assert(currRel != null, "Window Group Limit Rel should be valid")
TransformContext(childCtx.outputAttributes, output, currRel)
}
}
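
For orientation, a rough sketch (assumed shape, not plan output captured from this PR) of the physical plan for the "rank = 1" test queries below; the transformer prunes rows both before and after the exchange:

// Filter (rank = 1)
//   Window (still computes the rank column itself)
//     CHWindowGroupLimitExecTransformer(..., limit = 1, mode = Final)
//       Exchange (hash-partitioned on the partition keys)
//         CHWindowGroupLimitExecTransformer(..., limit = 1, mode = Partial)
//           Scan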
@@ -62,7 +62,7 @@ abstract class GlutenClickHouseTPCDSAbstractSuite
})

protected def fallbackSets(isAqe: Boolean): Set[Int] = {
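// TPC-DS q44, q67 and q70 filter on rank(); with WindowGroupLimitExec now
// offloaded, the Spark 3.5 fallback list is no longer needed.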
if (isSparkVersionGE("3.5")) Set(44, 67, 70) else Set.empty[Int]
Set.empty[Int]
}
protected def excludedTpcdsQueries: Set[String] = Set(
"q66" // inconsistent results
@@ -1855,7 +1855,7 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends GlutenClickHouseTPCHAbstr
| ) t1
|) t2 where rank = 1
""".stripMargin
compareResultsAgainstVanillaSpark(sql, true, { _ => }, isSparkVersionLE("3.3"))
compareResultsAgainstVanillaSpark(sql, true, { _ => })
}

test("GLUTEN-1874 not null in both streams") {
@@ -1873,7 +1873,7 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends GlutenClickHouseTPCHAbstr
| ) t1
|) t2 where rank = 1
""".stripMargin
compareResultsAgainstVanillaSpark(sql, true, { _ => }, isSparkVersionLE("3.3"))
compareResultsAgainstVanillaSpark(sql, true, { _ => })
}

test("GLUTEN-2095: test cast(string as binary)") {
@@ -2456,7 +2456,7 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends GlutenClickHouseTPCHAbstr
| ) t1
|) t2 where rank = 1 order by p_partkey limit 100
|""".stripMargin
runQueryAndCompare(sql, noFallBack = isSparkVersionLE("3.3"))({ _ => })
runQueryAndCompare(sql, noFallBack = true)({ _ => })
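// noFallBack = true: the rank-based top-k query is now fully offloaded on
// Spark 3.4/3.5 too, where it previously fell back to vanilla Spark.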
}

test("GLUTEN-4190: crush on flattening a const null column") {
41 changes: 22 additions & 19 deletions cpp-ch/local-engine/Common/CHUtil.cpp
@@ -53,6 +53,7 @@
#include <Parser/SubstraitParserUtils.h>
#include <Planner/PlannerActionsVisitor.h>
#include <Processors/Chunk.h>
#include <Processors/Formats/IOutputFormat.h>
#include <Processors/QueryPlan/ExpressionStep.h>
#include <Processors/QueryPlan/QueryPlan.h>
#include <QueryPipeline/QueryPipelineBuilder.h>
@@ -315,7 +316,6 @@ DB::Block BlockUtil::concatenateBlocksMemoryEfficiently(std::vector<DB::Block> &
return out;
}


size_t PODArrayUtil::adjustMemoryEfficientSize(size_t n)
{
/// According to the definition of DEFAULT_BLOCK_SIZE
@@ -560,9 +560,7 @@ std::map<std::string, std::string> BackendInitializerUtil::getBackendConfMap(std
}

std::vector<String> BackendInitializerUtil::wrapDiskPathConfig(
const String & path_prefix,
const String & path_suffix,
Poco::Util::AbstractConfiguration & config)
const String & path_prefix, const String & path_suffix, Poco::Util::AbstractConfiguration & config)
{
std::vector<String> changed_paths;
if (path_prefix.empty() && path_suffix.empty())
@@ -657,9 +655,7 @@ DB::Context::ConfigurationPtr BackendInitializerUtil::initConfig(std::map<std::s
auto path_need_clean = wrapDiskPathConfig("", "/" + pid, *config);
std::lock_guard lock(BackendFinalizerUtil::paths_mutex);
BackendFinalizerUtil::paths_need_to_clean.insert(
BackendFinalizerUtil::paths_need_to_clean.end(),
path_need_clean.begin(),
path_need_clean.end());
BackendFinalizerUtil::paths_need_to_clean.end(), path_need_clean.begin(), path_need_clean.end());
}
return config;
}
@@ -683,7 +679,9 @@ void BackendInitializerUtil::initEnvs(DB::Context::ConfigurationPtr config)
{
const std::string config_timezone = config->getString("timezone");
const String mapped_timezone = DateTimeUtil::convertTimeZone(config_timezone);
if (0 != setenv("TZ", mapped_timezone.data(), 1)) // NOLINT(concurrency-mt-unsafe) // ok if not called concurrently with other setenv/getenv
if (0
!= setenv(
"TZ", mapped_timezone.data(), 1)) // NOLINT(concurrency-mt-unsafe) // ok if not called concurrently with other setenv/getenv
throw Poco::Exception("Cannot setenv TZ variable");

tzset();
@@ -807,8 +805,7 @@ void BackendInitializerUtil::initSettings(std::map<std::string, std::string> & b
{
auto mem_gb = task_memory / static_cast<double>(1_GiB);
// Heuristic: external sort block size = 2.8 * mem_gb + 5 (MiB), clamped to [8, 16] MiB
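// e.g. a 4 GiB task: 2.8 * 4 + 5 = 16.2 -> clamped to 16, i.e. 16 MiB blocks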
settings.prefer_external_sort_block_bytes = std::max(std::min(
static_cast<size_t>(2.8*mem_gb + 5), 16ul), 8ul) * 1024 * 1024;
settings.prefer_external_sort_block_bytes = std::max(std::min(static_cast<size_t>(2.8 * mem_gb + 5), 16ul), 8ul) * 1024 * 1024;
}
}
}
@@ -848,10 +845,14 @@ void BackendInitializerUtil::initContexts(DB::Context::ConfigurationPtr config)

global_context->setMarkCache(mark_cache_policy, mark_cache_size, mark_cache_size_ratio);

String index_uncompressed_cache_policy = config->getString("index_uncompressed_cache_policy", DEFAULT_INDEX_UNCOMPRESSED_CACHE_POLICY);
size_t index_uncompressed_cache_size = config->getUInt64("index_uncompressed_cache_size", DEFAULT_INDEX_UNCOMPRESSED_CACHE_MAX_SIZE);
double index_uncompressed_cache_size_ratio = config->getDouble("index_uncompressed_cache_size_ratio", DEFAULT_INDEX_UNCOMPRESSED_CACHE_SIZE_RATIO);
global_context->setIndexUncompressedCache(index_uncompressed_cache_policy, index_uncompressed_cache_size, index_uncompressed_cache_size_ratio);
String index_uncompressed_cache_policy
= config->getString("index_uncompressed_cache_policy", DEFAULT_INDEX_UNCOMPRESSED_CACHE_POLICY);
size_t index_uncompressed_cache_size
= config->getUInt64("index_uncompressed_cache_size", DEFAULT_INDEX_UNCOMPRESSED_CACHE_MAX_SIZE);
double index_uncompressed_cache_size_ratio
= config->getDouble("index_uncompressed_cache_size_ratio", DEFAULT_INDEX_UNCOMPRESSED_CACHE_SIZE_RATIO);
global_context->setIndexUncompressedCache(
index_uncompressed_cache_policy, index_uncompressed_cache_size, index_uncompressed_cache_size_ratio);

String index_mark_cache_policy = config->getString("index_mark_cache_policy", DEFAULT_INDEX_MARK_CACHE_POLICY);
size_t index_mark_cache_size = config->getUInt64("index_mark_cache_size", DEFAULT_INDEX_MARK_CACHE_MAX_SIZE);
@@ -1023,11 +1024,13 @@ void BackendFinalizerUtil::finalizeGlobally()
StorageMergeTreeFactory::clear();
QueryContext::resetGlobal();
std::lock_guard lock(paths_mutex);
std::ranges::for_each(paths_need_to_clean, [](const auto & path)
{
if (fs::exists(path))
fs::remove_all(path);
});
std::ranges::for_each(
paths_need_to_clean,
[](const auto & path)
{
if (fs::exists(path))
fs::remove_all(path);
});
paths_need_to_clean.clear();
}
