diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
index 69ea899c42a5..45aee4322611 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
@@ -418,4 +418,6 @@ object CHBackendSettings extends BackendSettingsApi with Logging {
       }
     }
   }
+
+  override def supportWindowGroupLimitExec(rankLikeFunction: Expression): Boolean = true
 }
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
index 1108b8b3c501..f765a75d2f7d 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
@@ -53,6 +53,7 @@ import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleEx
 import org.apache.spark.sql.execution.joins.{BuildSideRelation, ClickHouseBuildSideRelation, HashedRelationBroadcastMode}
 import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.execution.utils.{CHExecUtil, PushDownUtil}
+import org.apache.spark.sql.execution.window._
 import org.apache.spark.sql.types.{DecimalType, StructType}
 import org.apache.spark.sql.vectorized.ColumnarBatch
 
@@ -909,4 +910,19 @@ class CHSparkPlanExecApi extends SparkPlanExecApi with Logging {
       toScale: Int): DecimalType = {
     SparkShimLoader.getSparkShims.genDecimalRoundExpressionOutput(decimalType, toScale)
   }
+
+  override def genWindowGroupLimitTransformer(
+      partitionSpec: Seq[Expression],
+      orderSpec: Seq[SortOrder],
+      rankLikeFunction: Expression,
+      limit: Int,
+      mode: WindowGroupLimitMode,
+      child: SparkPlan): SparkPlan =
+    CHWindowGroupLimitExecTransformer(
+      partitionSpec,
+      orderSpec,
+      rankLikeFunction,
+      limit,
+      mode,
+      child)
 }
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHWindowGroupLimitExecTransformer.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHWindowGroupLimitExecTransformer.scala
new file mode 100644
index 000000000000..c2648f29ec4c
--- /dev/null
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHWindowGroupLimitExecTransformer.scala
@@ -0,0 +1,187 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.execution
+
+import org.apache.gluten.backendsapi.BackendsApiManager
+import org.apache.gluten.exception.GlutenNotSupportException
+import org.apache.gluten.expression._
+import org.apache.gluten.extension.ValidationResult
+import org.apache.gluten.metrics.MetricsUpdater
+import org.apache.gluten.substrait.`type`.TypeBuilder
+import org.apache.gluten.substrait.SubstraitContext
+import org.apache.gluten.substrait.extensions.ExtensionBuilder
+import org.apache.gluten.substrait.rel.{RelBuilder, RelNode}
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning}
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.window.{Final, Partial, WindowGroupLimitMode}
+
+import com.google.protobuf.StringValue
+import io.substrait.proto.SortField
+
+import scala.collection.JavaConverters._
+
+case class CHWindowGroupLimitExecTransformer(
+    partitionSpec: Seq[Expression],
+    orderSpec: Seq[SortOrder],
+    rankLikeFunction: Expression,
+    limit: Int,
+    mode: WindowGroupLimitMode,
+    child: SparkPlan)
+  extends UnaryTransformSupport {
+
+  @transient override lazy val metrics =
+    BackendsApiManager.getMetricsApiInstance.genWindowTransformerMetrics(sparkContext)
+
+  override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
+    copy(child = newChild)
+
+  override def metricsUpdater(): MetricsUpdater =
+    BackendsApiManager.getMetricsApiInstance.genWindowTransformerMetricsUpdater(metrics)
+
+  override def output: Seq[Attribute] = child.output
+
+  override def requiredChildDistribution: Seq[Distribution] = mode match {
+    case Partial => super.requiredChildDistribution
+    case Final =>
+      if (partitionSpec.isEmpty) {
+        AllTuples :: Nil
+      } else {
+        ClusteredDistribution(partitionSpec) :: Nil
+      }
+  }
+
+  override def requiredChildOrdering: Seq[Seq[SortOrder]] = {
+    if (BackendsApiManager.getSettings.requiredChildOrderingForWindowGroupLimit()) {
+      Seq(partitionSpec.map(SortOrder(_, Ascending)) ++ orderSpec)
+    } else {
+      Seq(Nil)
+    }
+  }
+
+  override def outputOrdering: Seq[SortOrder] = {
+    if (requiredChildOrdering.forall(_.isEmpty)) {
+      // The Velox backend `TopNRowNumber` does not require child ordering, because it uses a
+      // hash table to store partitions and a priority queue to keep track of the top `limit`
+      // rows per partition. Ideally, the output of `TopNRowNumber` is unordered but grouped by
+      // the partition keys. To be safe, we do not propagate the ordering here.
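+      // For example, with "PARTITION BY a ORDER BY b" and limit 3, the operator may emit all of
+      // one partition's surviving rows before another partition's; rows of a partition stay
+      // contiguous, but there is no total order across partitions.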
+      // TODO: Make the framework aware of grouped data distribution
+      Nil
+    } else {
+      child.outputOrdering
+    }
+  }
+
+  override def outputPartitioning: Partitioning = child.outputPartitioning
+
+  def getWindowGroupLimitRel(
+      context: SubstraitContext,
+      originalInputAttributes: Seq[Attribute],
+      operatorId: Long,
+      input: RelNode,
+      validation: Boolean): RelNode = {
+    val args = context.registeredFunction
+    // Partition By Expressions
+    val partitionsExpressions = partitionSpec
+      .map(
+        ExpressionConverter
+          .replaceWithExpressionTransformer(_, attributeSeq = child.output)
+          .doTransform(args))
+      .asJava
+
+    // Sort By Expressions
+    val sortFieldList =
+      orderSpec.map {
+        order =>
+          val builder = SortField.newBuilder()
+          val exprNode = ExpressionConverter
+            .replaceWithExpressionTransformer(order.child, attributeSeq = child.output)
+            .doTransform(args)
+          builder.setExpr(exprNode.toProtobuf)
+          builder.setDirectionValue(SortExecTransformer.transformSortDirection(order))
+          builder.build()
+      }.asJava
+    if (!validation) {
+      val windowFunction = rankLikeFunction match {
+        case _: RowNumber => ExpressionNames.ROW_NUMBER
+        case _: Rank => ExpressionNames.RANK
+        case _: DenseRank => ExpressionNames.DENSE_RANK
+        case _ =>
+          throw new GlutenNotSupportException(s"Unknown window function $rankLikeFunction")
+      }
+      val parametersStr = new StringBuffer("WindowGroupLimitParameters:")
+      parametersStr
+        .append("window_function=")
+        .append(windowFunction)
+        .append("\n")
+      val message = StringValue.newBuilder().setValue(parametersStr.toString).build()
+      val extensionNode = ExtensionBuilder.makeAdvancedExtension(
+        BackendsApiManager.getTransformerApiInstance.packPBMessage(message),
+        null)
+      RelBuilder.makeWindowGroupLimitRel(
+        input,
+        partitionsExpressions,
+        sortFieldList,
+        limit,
+        extensionNode,
+        context,
+        operatorId)
+    } else {
+      // Use an extension node to send the input types through the Substrait plan for validation.
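+      // For example, for a child schema (a: bigint, b: string) the packed message is the struct
+      // type STRUCT<a: i64, b: string>, which the native side can unpack to check whether each
+      // input type is supported.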
+      val inputTypeNodeList = originalInputAttributes
+        .map(attr => ConverterUtils.getTypeNode(attr.dataType, attr.nullable))
+        .asJava
+      val extensionNode = ExtensionBuilder.makeAdvancedExtension(
+        BackendsApiManager.getTransformerApiInstance.packPBMessage(
+          TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf))
+
+      RelBuilder.makeWindowGroupLimitRel(
+        input,
+        partitionsExpressions,
+        sortFieldList,
+        limit,
+        extensionNode,
+        context,
+        operatorId)
+    }
+  }
+
+  override protected def doValidateInternal(): ValidationResult = {
+    if (!BackendsApiManager.getSettings.supportWindowGroupLimitExec(rankLikeFunction)) {
+      return ValidationResult
+        .failed(s"Found unsupported rank-like function: $rankLikeFunction")
+    }
+    val substraitContext = new SubstraitContext
+    val operatorId = substraitContext.nextOperatorId(this.nodeName)
+
+    val relNode =
+      getWindowGroupLimitRel(substraitContext, child.output, operatorId, null, validation = true)
+
+    doNativeValidation(substraitContext, relNode)
+  }
+
+  override protected def doTransform(context: SubstraitContext): TransformContext = {
+    val childCtx = child.asInstanceOf[TransformSupport].transform(context)
+    val operatorId = context.nextOperatorId(this.nodeName)
+
+    val currRel =
+      getWindowGroupLimitRel(context, child.output, operatorId, childCtx.root, validation = false)
+    assert(currRel != null, "Window Group Limit Rel should be valid")
+    TransformContext(childCtx.outputAttributes, output, currRel)
+  }
+}
diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCDSAbstractSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCDSAbstractSuite.scala
index 03b26fa985ea..abb7d27ffe92 100644
--- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCDSAbstractSuite.scala
+++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCDSAbstractSuite.scala
@@ -62,7 +62,7 @@ abstract class GlutenClickHouseTPCDSAbstractSuite
     })
 
   protected def fallbackSets(isAqe: Boolean): Set[Int] = {
-    if (isSparkVersionGE("3.5")) Set(44, 67, 70) else Set.empty[Int]
+    Set.empty[Int]
   }
 
   protected def excludedTpcdsQueries: Set[String] = Set(
     "q66" // inconsistent results
diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala
index f7cf0de3762d..9ac35441acb4 100644
--- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala
+++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala
@@ -1855,7 +1855,7 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends GlutenClickHouseTPCHAbstr
       |  ) t1
       |) t2 where rank = 1
       """.stripMargin
-    compareResultsAgainstVanillaSpark(sql, true, { _ => }, isSparkVersionLE("3.3"))
+    compareResultsAgainstVanillaSpark(sql, true, { _ => })
   }
 
   test("GLUTEN-1874 not null in both streams") {
@@ -1873,7 +1873,7 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends GlutenClickHouseTPCHAbstr
       |  ) t1
       |) t2 where rank = 1
       """.stripMargin
-    compareResultsAgainstVanillaSpark(sql, true, { _ => }, isSparkVersionLE("3.3"))
+    compareResultsAgainstVanillaSpark(sql, true, { _ => })
   }
 
   test("GLUTEN-2095: test cast(string as binary)") {
@@ -2456,7 +2456,7 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends GlutenClickHouseTPCHAbstr
       |  ) t1
       |) t2 where rank = 1 order by p_partkey limit 100
       |""".stripMargin
-    runQueryAndCompare(sql, noFallBack = isSparkVersionLE("3.3"))({ _ => })
+    runQueryAndCompare(sql, noFallBack = true)({ _ => })
   }
 
   test("GLUTEN-4190: crush on flattening a const null column") {
diff --git a/cpp-ch/local-engine/Common/CHUtil.cpp b/cpp-ch/local-engine/Common/CHUtil.cpp
index b702082d6ff3..94a214e5e571 100644
--- a/cpp-ch/local-engine/Common/CHUtil.cpp
+++ b/cpp-ch/local-engine/Common/CHUtil.cpp
@@ -53,6 +53,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -315,7 +316,6 @@ DB::Block BlockUtil::concatenateBlocksMemoryEfficiently(std::vector<DB::Block> &
     return out;
 }
 
-
 size_t PODArrayUtil::adjustMemoryEfficientSize(size_t n)
 {
     /// According to definition of DEFUALT_BLOCK_SIZE
@@ -560,9 +560,7 @@ std::map<std::string, std::string> BackendInitializerUtil::getBackendConfMap(std
 }
 
 std::vector<String> BackendInitializerUtil::wrapDiskPathConfig(
-    const String & path_prefix,
-    const String & path_suffix,
-    Poco::Util::AbstractConfiguration & config)
+    const String & path_prefix, const String & path_suffix, Poco::Util::AbstractConfiguration & config)
 {
     std::vector<String> changed_paths;
     if (path_prefix.empty() && path_suffix.empty())
@@ -657,9 +655,7 @@ DB::Context::ConfigurationPtr BackendInitializerUtil::initConfig(std::map<std::string, std::string>
         const String config_timezone = config->getString("timezone");
         const String mapped_timezone = DateTimeUtil::convertTimeZone(config_timezone);
-        if (0 != setenv("TZ", mapped_timezone.data(), 1)) // NOLINT(concurrency-mt-unsafe) // ok if not called concurrently with other setenv/getenv
+        if (0
+            != setenv(
+                "TZ", mapped_timezone.data(), 1)) // NOLINT(concurrency-mt-unsafe) // ok if not called concurrently with other setenv/getenv
             throw Poco::Exception("Cannot setenv TZ variable");
         tzset();
@@ -807,8 +805,7 @@ void BackendInitializerUtil::initSettings(std::map<std::string, std::string> & b
     {
         auto mem_gb = task_memory / static_cast<double>(1_GiB);
         // 2.8x+5, Heuristics calculate the block size of external sort, [8,16]
-        settings.prefer_external_sort_block_bytes = std::max(std::min(
-            static_cast<size_t>(2.8*mem_gb + 5), 16ul), 8ul) * 1024 * 1024;
+        settings.prefer_external_sort_block_bytes = std::max(std::min(static_cast<size_t>(2.8 * mem_gb + 5), 16ul), 8ul) * 1024 * 1024;
     }
 }
@@ -848,10 +845,14 @@ void BackendInitializerUtil::initContexts(DB::Context::ConfigurationPtr config)
 
     global_context->setMarkCache(mark_cache_policy, mark_cache_size, mark_cache_size_ratio);
 
-    String index_uncompressed_cache_policy = config->getString("index_uncompressed_cache_policy", DEFAULT_INDEX_UNCOMPRESSED_CACHE_POLICY);
-    size_t index_uncompressed_cache_size = config->getUInt64("index_uncompressed_cache_size", DEFAULT_INDEX_UNCOMPRESSED_CACHE_MAX_SIZE);
-    double index_uncompressed_cache_size_ratio = config->getDouble("index_uncompressed_cache_size_ratio", DEFAULT_INDEX_UNCOMPRESSED_CACHE_SIZE_RATIO);
-    global_context->setIndexUncompressedCache(index_uncompressed_cache_policy, index_uncompressed_cache_size, index_uncompressed_cache_size_ratio);
+    String index_uncompressed_cache_policy
+        = config->getString("index_uncompressed_cache_policy", DEFAULT_INDEX_UNCOMPRESSED_CACHE_POLICY);
+    size_t index_uncompressed_cache_size
+        = config->getUInt64("index_uncompressed_cache_size", DEFAULT_INDEX_UNCOMPRESSED_CACHE_MAX_SIZE);
+    double index_uncompressed_cache_size_ratio
+        = config->getDouble("index_uncompressed_cache_size_ratio", DEFAULT_INDEX_UNCOMPRESSED_CACHE_SIZE_RATIO);
+    global_context->setIndexUncompressedCache(
+        index_uncompressed_cache_policy, index_uncompressed_cache_size, index_uncompressed_cache_size_ratio);
 
     String index_mark_cache_policy = config->getString("index_mark_cache_policy", DEFAULT_INDEX_MARK_CACHE_POLICY);
     size_t index_mark_cache_size = config->getUInt64("index_mark_cache_size", DEFAULT_INDEX_MARK_CACHE_MAX_SIZE);
@@ -1023,11 +1024,13 @@ void BackendFinalizerUtil::finalizeGlobally()
     StorageMergeTreeFactory::clear();
     QueryContext::resetGlobal();
     std::lock_guard lock(paths_mutex);
-    std::ranges::for_each(paths_need_to_clean, [](const auto & path)
-    {
-        if (fs::exists(path))
-            fs::remove_all(path);
-    });
+    std::ranges::for_each(
+        paths_need_to_clean,
+        [](const auto & path)
+        {
+            if (fs::exists(path))
+                fs::remove_all(path);
+        });
     paths_need_to_clean.clear();
 }
diff --git a/cpp-ch/local-engine/Operator/ReplicateRowsStep.cpp b/cpp-ch/local-engine/Operator/ReplicateRowsStep.cpp
index f2d4bc8a865d..ecb027c18f0a 100644
--- a/cpp-ch/local-engine/Operator/ReplicateRowsStep.cpp
+++ b/cpp-ch/local-engine/Operator/ReplicateRowsStep.cpp
@@ -32,16 +32,14 @@ namespace local_engine
 {
 static DB::ITransformingStep::Traits getTraits()
 {
-    return DB::ITransformingStep::Traits
-    {
+    return DB::ITransformingStep::Traits{
         {
             .preserves_number_of_streams = true,
             .preserves_sorting = false,
         },
         {
             .preserves_number_of_rows = false,
-        }
-    };
+        }};
 }
 
 ReplicateRowsStep::ReplicateRowsStep(const DB::DataStream & input_stream)
@@ -49,7 +47,7 @@ ReplicateRowsStep::ReplicateRowsStep(const DB::DataStream & input_stream)
 {
 }
 
-DB::Block ReplicateRowsStep::transformHeader(const DB::Block& input)
+DB::Block ReplicateRowsStep::transformHeader(const DB::Block & input)
 {
     DB::Block output;
     for (int i = 1; i < input.columns(); i++)
@@ -59,15 +57,9 @@ DB::Block ReplicateRowsStep::transformHeader(const DB::Block& input)
     return output;
 }
 
-void ReplicateRowsStep::transformPipeline(
-    DB::QueryPipelineBuilder & pipeline,
-    const DB::BuildQueryPipelineSettings & /*settings*/)
+void ReplicateRowsStep::transformPipeline(DB::QueryPipelineBuilder & pipeline, const DB::BuildQueryPipelineSettings & /*settings*/)
 {
-    pipeline.addSimpleTransform(
-        [&](const DB::Block & header)
-        {
-            return std::make_shared<ReplicateRowsTransform>(header);
-        });
+    pipeline.addSimpleTransform([&](const DB::Block & header) { return std::make_shared<ReplicateRowsTransform>(header); });
 }
 
 void ReplicateRowsStep::updateOutputStream()
@@ -105,4 +97,4 @@ void ReplicateRowsTransform::transform(DB::Chunk & chunk)
     chunk.setColumns(std::move(mutable_columns), total_rows);
 }
 
-}
\ No newline at end of file
+}
diff --git a/cpp-ch/local-engine/Operator/WindowGroupLimitStep.cpp b/cpp-ch/local-engine/Operator/WindowGroupLimitStep.cpp
new file mode 100644
index 000000000000..af04ef579028
--- /dev/null
+++ b/cpp-ch/local-engine/Operator/WindowGroupLimitStep.cpp
@@ -0,0 +1,365 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "WindowGroupLimitStep.h"
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+namespace DB::ErrorCodes
+{
+extern const int LOGICAL_ERROR;
+}
+
+namespace local_engine
+{
+
+enum class WindowGroupLimitFunction
+{
+    RowNumber,
+    Rank,
+    DenseRank
+};
+
+
+template <WindowGroupLimitFunction function>
+class WindowGroupLimitTransform : public DB::IProcessor
+{
+public:
+    using Status = DB::IProcessor::Status;
+    explicit WindowGroupLimitTransform(
+        const DB::Block & header_, const std::vector<size_t> & partition_columns_, const std::vector<size_t> & sort_columns_, size_t limit_)
+        : DB::IProcessor({header_}, {header_})
+        , header(header_)
+        , partition_columns(partition_columns_)
+        , sort_columns(sort_columns_)
+        , limit(limit_)
+
+    {
+    }
+    ~WindowGroupLimitTransform() override = default;
+    String getName() const override { return "WindowGroupLimitTransform"; }
+
+    Status prepare() override
+    {
+        auto & output_port = outputs.front();
+        auto & input_port = inputs.front();
+        if (output_port.isFinished())
+        {
+            input_port.close();
+            return Status::Finished;
+        }
+
+        if (has_output)
+        {
+            if (output_port.canPush())
+            {
+                output_port.push(std::move(output_chunk));
+                has_output = false;
+            }
+            return Status::PortFull;
+        }
+
+        if (has_input)
+            return Status::Ready;
+
+        if (input_port.isFinished())
+        {
+            output_port.finish();
+            return Status::Finished;
+        }
+        input_port.setNeeded();
+        if (!input_port.hasData())
+            return Status::NeedData;
+        input_chunk = input_port.pull(true);
+        has_input = true;
+        return Status::Ready;
+    }
+
+    void work() override
+    {
+        if (!has_input) [[unlikely]]
+        {
+            return;
+        }
+        DB::Block block = header.cloneWithColumns(input_chunk.getColumns());
+        size_t partition_start_row = 0;
+        size_t chunk_rows = input_chunk.getNumRows();
+        while (partition_start_row < chunk_rows)
+        {
+            auto next_partition_start_row = advanceNextPartition(input_chunk, partition_start_row);
+            iteratePartition(input_chunk, partition_start_row, next_partition_start_row);
+            partition_start_row = next_partition_start_row;
+            // Corner case: the partition's end row is the last row of the chunk.
+            if (partition_start_row < chunk_rows)
+            {
+                current_row_rank_value = 1;
+                if constexpr (function == WindowGroupLimitFunction::Rank)
+                    current_peer_group_rows = 0;
+                partition_start_row_columns = extractOneRowColumns(input_chunk, partition_start_row);
+            }
+        }
+
+        if (!output_columns.empty() && output_columns[0]->size() > 0)
+        {
+            auto rows = output_columns[0]->size();
+            output_chunk = DB::Chunk(std::move(output_columns), rows);
+            output_columns.clear();
+            has_output = true;
+        }
+        has_input = false;
+    }
+
+private:
+    DB::Block header;
+    // Which columns are used as the partition keys.
+    std::vector<size_t> partition_columns;
+    // Which columns are used as the order-by keys, excluding partition columns.
+    std::vector<size_t> sort_columns;
+    // Row limit for each partition.
+    size_t limit = 0;
+
+    bool has_input = false;
+    DB::Chunk input_chunk;
+    bool has_output = false;
+    DB::MutableColumns output_columns;
+    DB::Chunk output_chunk;
+
+    // We don't have a window frame here; in fact, every frame is (unbounded preceding, current row].
+    // The start value is 1.
+    size_t current_row_rank_value = 1;
+    // rank needs this to record how many rows are in the current peer group.
+    // A peer group in a partition is defined as the rows that have the same values on the sort columns.
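+    // For example, with ORDER BY x over the values [5, 5, 7]: the two 5s form one peer group and
+    // both get rank 1; 7 then starts a new peer group with rank 3 (1 + 2 rows in the previous
+    // group), whereas dense_rank would assign it 2.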
+    size_t current_peer_group_rows = 0;
+
+    DB::Columns partition_start_row_columns;
+    DB::Columns peer_group_start_row_columns;
+
+
+    size_t advanceNextPartition(const DB::Chunk & chunk, size_t start_offset)
+    {
+        if (partition_start_row_columns.empty())
+            partition_start_row_columns = extractOneRowColumns(chunk, start_offset);
+
+        size_t max_row = chunk.getNumRows();
+        for (size_t i = start_offset; i < max_row; ++i)
+        {
+            if (!isRowEqual(partition_columns, partition_start_row_columns, 0, chunk.getColumns(), i))
+            {
+                return i;
+            }
+        }
+        return max_row;
+    }
+
+    static DB::Columns extractOneRowColumns(const DB::Chunk & chunk, size_t offset)
+    {
+        DB::Columns row;
+        for (const auto & col : chunk.getColumns())
+        {
+            auto new_col = col->cloneEmpty();
+            new_col->insertFrom(*col, offset);
+            row.push_back(std::move(new_col));
+        }
+        return row;
+    }
+
+    static bool isRowEqual(
+        const std::vector<size_t> & fields, const DB::Columns & left_cols, size_t loffset, const DB::Columns & right_cols, size_t roffset)
+    {
+        for (size_t i = 0; i < fields.size(); ++i)
+        {
+            const auto & field = fields[i];
+            /// don't care about nan_direction_hint
+            if (left_cols[field]->compareAt(loffset, roffset, *right_cols[field], 1))
+                return false;
+        }
+        return true;
+    }
+
+    void iteratePartition(const DB::Chunk & chunk, size_t start_offset, size_t end_offset)
+    {
+        // Skip the remaining rows in this partition.
+        if (current_row_rank_value > limit)
+            return;
+
+
+        size_t chunk_rows = chunk.getNumRows();
+        auto has_peer_group_ended = [&](size_t offset, size_t partition_end_offset, size_t chunk_rows_)
+        { return offset < partition_end_offset || end_offset < chunk_rows_; };
+        auto try_end_peer_group
+            = [&](size_t peer_group_start_offset, size_t next_peer_group_start_offset, size_t partition_end_offset, size_t chunk_rows_)
+        {
+            if constexpr (function == WindowGroupLimitFunction::Rank)
+            {
+                current_peer_group_rows += next_peer_group_start_offset - peer_group_start_offset;
+                if (has_peer_group_ended(next_peer_group_start_offset, partition_end_offset, chunk_rows_))
+                {
+                    current_row_rank_value += current_peer_group_rows;
+                    current_peer_group_rows = 0;
+                    peer_group_start_row_columns = extractOneRowColumns(chunk, next_peer_group_start_offset);
+                }
+            }
+            else if constexpr (function == WindowGroupLimitFunction::DenseRank)
+            {
+                if (has_peer_group_ended(next_peer_group_start_offset, partition_end_offset, chunk_rows_))
+                {
+                    current_row_rank_value += 1;
+                    peer_group_start_row_columns = extractOneRowColumns(chunk, next_peer_group_start_offset);
+                }
+            }
+        };
+
+        // This is a corner case: the previous partition's last row is the last row of a chunk.
+        if (start_offset >= end_offset)
+        {
+            assert(!start_offset);
+            try_end_peer_group(start_offset, end_offset, end_offset, chunk_rows);
+            return;
+        }
+
+        // row_number is simple.
+        if constexpr (function == WindowGroupLimitFunction::RowNumber)
+        {
+            size_t rows = end_offset - start_offset;
+            size_t limit_remained = limit - current_row_rank_value + 1;
+            rows = rows > limit_remained ? limit_remained : rows;
+            insertResultValue(chunk, start_offset, rows);
+            current_row_rank_value += rows;
+        }
+        else
+        {
+            size_t peer_group_start_offset = start_offset;
+            while (peer_group_start_offset < end_offset && current_row_rank_value <= limit)
+            {
+                auto next_peer_group_start_offset = advanceNextPeerGroup(chunk, peer_group_start_offset, end_offset);
+
+                insertResultValue(chunk, peer_group_start_offset, next_peer_group_start_offset - peer_group_start_offset);
+                try_end_peer_group(peer_group_start_offset, next_peer_group_start_offset, end_offset, chunk_rows);
+                peer_group_start_offset = next_peer_group_start_offset;
+            }
+        }
+    }
+    void insertResultValue(const DB::Chunk & chunk, size_t start_offset, size_t rows)
+    {
+        if (!rows)
+            return;
+        if (output_columns.empty())
+        {
+            for (const auto & col : chunk.getColumns())
+            {
+                output_columns.push_back(col->cloneEmpty());
+            }
+        }
+        size_t i = 0;
+        for (const auto & col : chunk.getColumns())
+        {
+            output_columns[i]->insertRangeFrom(*col, start_offset, rows);
+            i += 1;
+        }
+    }
+    size_t advanceNextPeerGroup(const DB::Chunk & chunk, size_t start_offset, size_t partition_end_offset)
+    {
+        if (peer_group_start_row_columns.empty())
+            peer_group_start_row_columns = extractOneRowColumns(chunk, start_offset);
+        for (size_t i = start_offset; i < partition_end_offset; ++i)
+        {
+            if (!isRowEqual(sort_columns, peer_group_start_row_columns, 0, chunk.getColumns(), i))
+            {
+                return i;
+            }
+        }
+        return partition_end_offset;
+    }
+};
+
+static DB::ITransformingStep::Traits getTraits()
+{
+    return DB::ITransformingStep::Traits{
+        {
+            .preserves_number_of_streams = false,
+            .preserves_sorting = true,
+        },
+        {
+            .preserves_number_of_rows = false,
+        }};
+}
+
+WindowGroupLimitStep::WindowGroupLimitStep(
+    const DB::DataStream & input_stream_,
+    const String & function_name_,
+    const std::vector<size_t> partition_columns_,
+    const std::vector<size_t> sort_columns_,
+    size_t limit_)
+    : DB::ITransformingStep(input_stream_, input_stream_.header, getTraits())
+    , function_name(function_name_)
+    , partition_columns(partition_columns_)
+    , sort_columns(sort_columns_)
+    , limit(limit_)
+{
+}
+
+void WindowGroupLimitStep::describePipeline(DB::IQueryPlanStep::FormatSettings & settings) const
+{
+    if (!processors.empty())
+        DB::IQueryPlanStep::describePipeline(processors, settings);
+}
+
+void WindowGroupLimitStep::updateOutputStream()
+{
+    output_stream = createOutputStream(input_streams.front(), input_streams.front().header, getDataStreamTraits());
+}
+
+
+void WindowGroupLimitStep::transformPipeline(DB::QueryPipelineBuilder & pipeline, const DB::BuildQueryPipelineSettings & /*settings*/)
+{
+    if (function_name == "row_number")
+    {
+        pipeline.addSimpleTransform(
+            [&](const DB::Block & header)
+            {
+                return std::make_shared<WindowGroupLimitTransform<WindowGroupLimitFunction::RowNumber>>(
+                    header, partition_columns, sort_columns, limit);
+            });
+    }
+    else if (function_name == "rank")
+    {
+        pipeline.addSimpleTransform(
+            [&](const DB::Block & header)
+            {
+                return std::make_shared<WindowGroupLimitTransform<WindowGroupLimitFunction::Rank>>(
+                    header, partition_columns, sort_columns, limit);
+            });
+    }
+    else if (function_name == "dense_rank")
+    {
+        pipeline.addSimpleTransform(
+            [&](const DB::Block & header)
+            {
+                return std::make_shared<WindowGroupLimitTransform<WindowGroupLimitFunction::DenseRank>>(
+                    header, partition_columns, sort_columns, limit);
+            });
+    }
+    else
+    {
+        throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Unsupported function {} in WindowGroupLimit", function_name);
+    }
+}
+}
diff --git a/cpp-ch/local-engine/Operator/WindowGroupLimitStep.h b/cpp-ch/local-engine/Operator/WindowGroupLimitStep.h
new file mode 100644
index 000000000000..bbbbf42abc55
--- /dev/null
+++ b/cpp-ch/local-engine/Operator/WindowGroupLimitStep.h
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+#include
+#include
+#include
+#include
+
+namespace local_engine
+{
+class WindowGroupLimitStep : public DB::ITransformingStep
+{
+public:
+    explicit WindowGroupLimitStep(
+        const DB::DataStream & input_stream_,
+        const String & function_name_,
+        const std::vector<size_t> partition_columns_,
+        const std::vector<size_t> sort_columns_,
+        size_t limit_);
+    ~WindowGroupLimitStep() override = default;
+
+    String getName() const override { return "WindowGroupLimitStep"; }
+
+    void transformPipeline(DB::QueryPipelineBuilder & pipeline, const DB::BuildQueryPipelineSettings & settings) override;
+    void describePipeline(DB::IQueryPlanStep::FormatSettings & settings) const override;
+    void updateOutputStream() override;
+
+private:
+    // Window function name, one of row_number, rank and dense_rank.
+    String function_name;
+    std::vector<size_t> partition_columns;
+    std::vector<size_t> sort_columns;
+    size_t limit;
+};
+
+}
diff --git a/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.cpp b/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.cpp
index 42d4f4d4d8cd..cc7738a15aaa 100644
--- a/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.cpp
+++ b/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.cpp
@@ -14,25 +14,26 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#include
+#include "AdvancedParametersParseUtil.h"
 #include
 #include
-#include
-#include
+#include
 #include
+#include
 #include
+
 namespace DB::ErrorCodes
 {
-    extern const int BAD_ARGUMENTS;
+extern const int BAD_ARGUMENTS;
 }
 
 namespace local_engine
 {
-template
+template <typename T>
 void tryAssign(const std::unordered_map<String, String> & kvs, const String & key, T & v);
 
-template<>
+template <>
 void tryAssign(const std::unordered_map<String, String> & kvs, const String & key, String & v)
 {
     auto it = kvs.find(key);
@@ -40,7 +41,7 @@ void tryAssign(const std::unordered_map<String, String> & kvs, const Str
         v = it->second;
 }
 
-template<>
+template <>
 void tryAssign(const std::unordered_map<String, String> & kvs, const String & key, bool & v)
 {
     auto it = kvs.find(key);
@@ -57,7 +58,7 @@
     }
 }
 
-template<>
+template <>
 void tryAssign(const std::unordered_map<String, String> & kvs, const String & key, Int64 & v)
 {
     auto it = kvs.find(key);
@@ -94,9 +95,9 @@ void readStringUntilCharsInto(String & s, DB::ReadBuffer & buf)
 std::unordered_map<String, std::unordered_map<String, String>> convertToKVs(const String & advance)
 {
     std::unordered_map<String, std::unordered_map<String, String>> res;
     std::unordered_map<String, String> * kvs;
     DB::ReadBufferFromString in(advance);
-    while(!in.eof())
+    while (!in.eof())
     {
         String key;
         readStringUntilCharsInto<'=', '\n', ':'>(key, in);
@@ -146,5 +147,13 @@ JoinOptimizationInfo JoinOptimizationInfo::parse(const String & advance)
     tryAssign(kvs, "numPartitions", info.partitions_num);
     return info;
 }
-}
+WindowGroupOptimizationInfo WindowGroupOptimizationInfo::parse(const String & advance)
+{
+    WindowGroupOptimizationInfo info;
+    auto kkvs = convertToKVs(advance);
+    auto & kvs = kkvs["WindowGroupLimitParameters"];
+    tryAssign(kvs, "window_function", info.window_function);
+    return info;
+}
+}
diff --git a/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.h b/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.h
index 5f6fe6d256e3..fc478db33bfd 100644
--- a/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.h
+++ b/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.h
@@ -16,10 +16,10 @@
  */
 #pragma once
 #include
+#include
 
 namespace local_engine
 {
-
 std::unordered_map<String, std::unordered_map<String, String>>
 convertToKVs(const String & advance);
 
@@ -38,5 +38,10 @@ struct JoinOptimizationInfo
     static JoinOptimizationInfo parse(const String & advance);
 };
 
-}
+struct WindowGroupOptimizationInfo
+{
+    String window_function;
+    static WindowGroupOptimizationInfo parse(const String & advance);
+};
+}
diff --git a/cpp-ch/local-engine/Parser/RelParser.cpp b/cpp-ch/local-engine/Parser/RelParser.cpp
index f651146a391d..a7f6d0586455 100644
--- a/cpp-ch/local-engine/Parser/RelParser.cpp
+++ b/cpp-ch/local-engine/Parser/RelParser.cpp
@@ -30,8 +30,8 @@ namespace DB
 {
 namespace ErrorCodes
 {
-    extern const int BAD_ARGUMENTS;
-    extern const int LOGICAL_ERROR;
+extern const int BAD_ARGUMENTS;
+extern const int LOGICAL_ERROR;
 }
 }
 
@@ -89,14 +89,15 @@ DB::QueryPlanPtr RelParser::parseOp(const substrait::Rel & rel, std::list
 
-std::map<std::string, std::string> RelParser::parseFormattedRelAdvancedOptimization(const substrait::extensions::AdvancedExtension &advanced_extension)
+std::map<std::string, std::string>
+RelParser::parseFormattedRelAdvancedOptimization(const substrait::extensions::AdvancedExtension & advanced_extension)
 {
     std::map<std::string, std::string> configs;
     if (advanced_extension.has_optimization())
     {
         google::protobuf::StringValue msg;
         advanced_extension.optimization().UnpackTo(&msg);
-        Poco::StringTokenizer kvs( msg.value(), "\n");
+        Poco::StringTokenizer kvs(msg.value(), "\n");
         for (auto & kv : kvs)
         {
             if (kv.empty())
@@ -114,7 +115,8 @@
     return configs;
 }
 
-std::string RelParser::getStringConfig(const std::map<std::string, std::string> & configs, const std::string & key, const std::string & default_value)
+std::string
+RelParser::getStringConfig(const std::map<std::string, std::string> & configs, const std::string & key, const std::string & default_value)
 {
     auto it = configs.find(key);
     if (it == configs.end())
@@ -150,6 +152,7 @@ RelParserFactory::RelParserBuilder RelParserFactory::getBuilder(UInt32 k)
 }
 
 void registerWindowRelParser(RelParserFactory & factory);
+void registerWindowGroupLimitRelParser(RelParserFactory & factory);
 void registerSortRelParser(RelParserFactory & factory);
 void registerExpandRelParser(RelParserFactory & factory);
 void registerAggregateParser(RelParserFactory & factory);
@@ -162,6 +165,7 @@ void registerRelParsers()
 {
     auto & factory = RelParserFactory::instance();
     registerWindowRelParser(factory);
+    registerWindowGroupLimitRelParser(factory);
     registerSortRelParser(factory);
     registerExpandRelParser(factory);
     registerAggregateParser(factory);
diff --git a/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp b/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp
index 8efbd97d240d..fc0650993766 100644
--- a/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp
+++ b/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp
@@ -389,9 +389,9 @@ void adjustOutput(const DB::QueryPlanPtr & query_plan, const substrait::PlanRel
     }
     if (need_final_project)
     {
-        ActionsDAG final_project
-            = ActionsDAG::makeConvertingActions(original_cols, final_cols, ActionsDAG::MatchColumnsMode::Position);
-        QueryPlanStepPtr final_project_step = std::make_unique<ExpressionStep>(query_plan->getCurrentDataStream(), std::move(final_project));
+        ActionsDAG final_project = ActionsDAG::makeConvertingActions(original_cols, final_cols, ActionsDAG::MatchColumnsMode::Position);
+        QueryPlanStepPtr final_project_step
+            = std::make_unique<ExpressionStep>(query_plan->getCurrentDataStream(), std::move(final_project));
         final_project_step->setStepDescription("Project for output schema");
         query_plan->addStep(std::move(final_project_step));
     }
@@ -499,6 +499,7 @@ QueryPlanPtr SerializedPlanParser::parseOp(const substrait::Rel & rel, std::list
         case substrait::Rel::RelTypeCase::kWindow:
         case substrait::Rel::RelTypeCase::kJoin:
         case substrait::Rel::RelTypeCase::kCross:
+        case substrait::Rel::RelTypeCase::kWindowGroupLimit:
         case substrait::Rel::RelTypeCase::kExpand: {
             auto op_parser = RelParserFactory::instance().getBuilder(rel.rel_type_case())(this);
             query_plan = op_parser->parseOp(rel, rel_stack);
@@ -601,7 +602,7 @@ void SerializedPlanParser::parseArrayJoinArguments(
 }
 
 ActionsDAG::NodeRawConstPtrs SerializedPlanParser::parseArrayJoinWithDAG(
-    const substrait::Expression & rel, std::vector<String> & result_names, ActionsDAG& actions_dag, bool keep_result, bool position)
+    const substrait::Expression & rel, std::vector<String> & result_names, ActionsDAG & actions_dag, bool keep_result, bool position)
 {
     if (!rel.has_scalar_function())
         throw Exception(ErrorCodes::BAD_ARGUMENTS, "The root of expression should be a scalar function:\n {}", rel.DebugString());
@@ -718,7 +719,7 @@ ActionsDAG::NodeRawConstPtrs SerializedPlanParser::parseArrayJoinWithDAG(
 }
 
 const ActionsDAG::Node * SerializedPlanParser::parseFunctionWithDAG(
-    const substrait::Expression & rel, std::string & result_name, ActionsDAG& actions_dag, bool keep_result)
+    const substrait::Expression & rel, std::string & result_name, ActionsDAG & actions_dag, bool keep_result)
 {
     if (!rel.has_scalar_function())
         throw Exception(ErrorCodes::BAD_ARGUMENTS, "the root of expression should be a scalar function:\n {}", rel.DebugString());
{}", rel.DebugString()); @@ -780,9 +781,8 @@ bool SerializedPlanParser::isFunction(substrait::Expression_ScalarFunction rel, } void SerializedPlanParser::parseFunctionOrExpression( - const substrait::Expression & rel, std::string & result_name, ActionsDAG& actions_dag, bool keep_result) + const substrait::Expression & rel, std::string & result_name, ActionsDAG & actions_dag, bool keep_result) { - if (rel.has_scalar_function()) parseFunctionWithDAG(rel, result_name, actions_dag, keep_result); else @@ -793,11 +793,7 @@ void SerializedPlanParser::parseFunctionOrExpression( } void SerializedPlanParser::parseJsonTuple( - const substrait::Expression & rel, - std::vector & result_names, - ActionsDAG& actions_dag, - bool keep_result, - bool) + const substrait::Expression & rel, std::vector & result_names, ActionsDAG & actions_dag, bool keep_result, bool) { const auto & scalar_function = rel.scalar_function(); auto function_signature = function_mapping.at(std::to_string(rel.scalar_function().function_reference())); @@ -856,7 +852,7 @@ void SerializedPlanParser::parseJsonTuple( } const ActionsDAG::Node * -SerializedPlanParser::toFunctionNode(ActionsDAG& actions_dag, const String & function, const ActionsDAG::NodeRawConstPtrs & args) +SerializedPlanParser::toFunctionNode(ActionsDAG & actions_dag, const String & function, const ActionsDAG::NodeRawConstPtrs & args) { auto function_builder = FunctionFactory::instance().get(function, context); std::string args_name = join(args, ','); @@ -1068,7 +1064,7 @@ std::pair SerializedPlanParser::parseLiteral(const substrait return std::make_pair(std::move(type), std::move(field)); } -const ActionsDAG::Node * SerializedPlanParser::parseExpression(ActionsDAG& actions_dag, const substrait::Expression & rel) +const ActionsDAG::Node * SerializedPlanParser::parseExpression(ActionsDAG & actions_dag, const substrait::Expression & rel) { switch (rel.rex_type_case()) { @@ -1533,8 +1529,7 @@ ASTPtr ASTParser::parseArgumentToAST(const Names & names, const substrait::Expre } } -void SerializedPlanParser::removeNullableForRequiredColumns( - const std::set & require_columns, ActionsDAG & actions_dag) const +void SerializedPlanParser::removeNullableForRequiredColumns(const std::set & require_columns, ActionsDAG & actions_dag) const { for (const auto & item : require_columns) { @@ -1549,7 +1544,7 @@ void SerializedPlanParser::removeNullableForRequiredColumns( } void SerializedPlanParser::wrapNullable( - const std::vector & columns, ActionsDAG& actions_dag, std::map & nullable_measure_names) + const std::vector & columns, ActionsDAG & actions_dag, std::map & nullable_measure_names) { for (const auto & item : columns) { diff --git a/cpp-ch/local-engine/Parser/WindowGroupLimitRelParser.cpp b/cpp-ch/local-engine/Parser/WindowGroupLimitRelParser.cpp new file mode 100644 index 000000000000..f6c10386f405 --- /dev/null +++ b/cpp-ch/local-engine/Parser/WindowGroupLimitRelParser.cpp @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "WindowGroupLimitRelParser.h"
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include "AdvancedParametersParseUtil.h"
+
+namespace DB::ErrorCodes
+{
+extern const int BAD_ARGUMENTS;
+}
+
+namespace local_engine
+{
+WindowGroupLimitRelParser::WindowGroupLimitRelParser(SerializedPlanParser * plan_parser_) : RelParser(plan_parser_)
+{
+}
+
+DB::QueryPlanPtr
+WindowGroupLimitRelParser::parse(DB::QueryPlanPtr current_plan_, const substrait::Rel & rel, std::list<const substrait::Rel *> & rel_stack_)
+{
+    const auto & win_rel_def = rel.windowgrouplimit();
+    google::protobuf::StringValue optimize_info_str;
+    optimize_info_str.ParseFromString(win_rel_def.advanced_extension().optimization().value());
+    auto optimization_info = WindowGroupOptimizationInfo::parse(optimize_info_str.value());
+    window_function_name = optimization_info.window_function;
+
+    current_plan = std::move(current_plan_);
+
+    auto partition_fields = parsePartitionFields(win_rel_def.partition_expressions());
+    auto sort_fields = parseSortFields(win_rel_def.sorts());
+    size_t limit = static_cast<size_t>(win_rel_def.limit());
+
+    auto window_group_limit_step = std::make_unique<WindowGroupLimitStep>(
+        current_plan->getCurrentDataStream(), window_function_name, partition_fields, sort_fields, limit);
+    window_group_limit_step->setStepDescription("Window group limit");
+    steps.emplace_back(window_group_limit_step.get());
+    current_plan->addStep(std::move(window_group_limit_step));
+
+    return std::move(current_plan);
+}
+
+std::vector<size_t>
+WindowGroupLimitRelParser::parsePartitionFields(const google::protobuf::RepeatedPtrField<substrait::Expression> & expressions)
+{
+    std::vector<size_t> fields;
+    for (const auto & expr : expressions)
+    {
+        if (expr.has_selection())
+        {
+            fields.push_back(static_cast<size_t>(expr.selection().direct_reference().struct_field().field()));
+        }
+        else if (expr.has_literal())
+        {
+            continue;
+        }
+        else
+        {
+            throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Unknown expression: {}", expr.DebugString());
+        }
+    }
+    return fields;
+}
+
+std::vector<size_t> WindowGroupLimitRelParser::parseSortFields(const google::protobuf::RepeatedPtrField<substrait::SortField> & sort_fields)
+{
+    std::vector<size_t> fields;
+    for (const auto & sort_field : sort_fields)
+    {
+        if (sort_field.expr().has_literal())
+        {
+            continue;
+        }
+        else if (sort_field.expr().has_selection())
+        {
+            fields.push_back(static_cast<size_t>(sort_field.expr().selection().direct_reference().struct_field().field()));
+        }
+        else
+        {
+            throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Unknown expression: {}", sort_field.expr().DebugString());
+        }
+    }
+    return fields;
+}
+
+void registerWindowGroupLimitRelParser(RelParserFactory & factory)
+{
+    auto builder = [](SerializedPlanParser * plan_parser) { return std::make_shared<WindowGroupLimitRelParser>(plan_parser); };
+    factory.registerBuilder(substrait::Rel::RelTypeCase::kWindowGroupLimit, builder);
+}
+}
diff --git a/cpp-ch/local-engine/Parser/WindowGroupLimitRelParser.h b/cpp-ch/local-engine/Parser/WindowGroupLimitRelParser.h
new file mode 100644
index 000000000000..c9c503ed4745
--- /dev/null
+++ b/cpp-ch/local-engine/Parser/WindowGroupLimitRelParser.h
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+namespace local_engine
+{
+/// Similar to WindowRelParser, with some differences:
+/// 1. aggregate functions are not supported; only the window functions row_number, rank and dense_rank are.
+/// 2. row_number, rank and dense_rank are mapped to new variants.
+/// 3. the output columns don't contain the window function results.
+class WindowGroupLimitRelParser : public RelParser
+{
+public:
+    explicit WindowGroupLimitRelParser(SerializedPlanParser * plan_parser_);
+    ~WindowGroupLimitRelParser() override = default;
+    DB::QueryPlanPtr
+    parse(DB::QueryPlanPtr current_plan_, const substrait::Rel & rel, std::list<const substrait::Rel *> & rel_stack_) override;
+    const substrait::Rel & getSingleInput(const substrait::Rel & rel) override { return rel.windowgrouplimit().input(); }
+
+private:
+    DB::QueryPlanPtr current_plan;
+    String window_function_name;
+
+    std::vector<size_t> parsePartitionFields(const google::protobuf::RepeatedPtrField<substrait::Expression> & expressions);
+    std::vector<size_t> parseSortFields(const google::protobuf::RepeatedPtrField<substrait::SortField> & sort_fields);
+};
+}
diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala b/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
index a55926d76d12..dd4150806cfc 100644
--- a/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
+++ b/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
@@ -41,6 +41,7 @@ import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
 import org.apache.spark.sql.execution.joins.BuildSideRelation
 import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.execution.python.ArrowEvalPythonExec
+import org.apache.spark.sql.execution.window._
 import org.apache.spark.sql.hive.{HiveTableScanExecTransformer, HiveUDFTransformer}
 import org.apache.spark.sql.types.{DecimalType, LongType, NullType, StructType}
 import org.apache.spark.sql.vectorized.ColumnarBatch
@@ -678,6 +679,15 @@ trait SparkPlanExecApi {
     }
   }
 
+  def genWindowGroupLimitTransformer(
+      partitionSpec: Seq[Expression],
+      orderSpec: Seq[SortOrder],
+      rankLikeFunction: Expression,
+      limit: Int,
+      mode: WindowGroupLimitMode,
+      child: SparkPlan): SparkPlan =
+    WindowGroupLimitExecTransformer(partitionSpec, orderSpec, rankLikeFunction, limit, mode, child)
+
   def genHiveUDFTransformer(
       expr: Expression,
       attributeSeq: Seq[Attribute]): ExpressionTransformer = {
diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/OffloadSingleNode.scala b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/OffloadSingleNode.scala
index 6047789e6abe..5440481f8ddd 100644
--- a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/OffloadSingleNode.scala
+++ b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/OffloadSingleNode.scala
@@ -358,7 +358,7 @@ object OffloadOthers {
         val windowGroupLimitPlan = SparkShimLoader.getSparkShims
           .getWindowGroupLimitExecShim(plan)
           .asInstanceOf[WindowGroupLimitExecShim]
-        WindowGroupLimitExecTransformer(
+        BackendsApiManager.getSparkPlanExecApiInstance.genWindowGroupLimitTransformer(
          windowGroupLimitPlan.partitionSpec,
          windowGroupLimitPlan.orderSpec,
          windowGroupLimitPlan.rankLikeFunction,