diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHMetricsApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHMetricsApi.scala index 73b2d0f21101..a0576a807b98 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHMetricsApi.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHMetricsApi.scala @@ -450,6 +450,14 @@ class CHMetricsApi extends MetricsApi with Logging with LogLevelUtil { s"SampleTransformer metrics update is not supported in CH backend") } + override def genUnionTransformerMetrics(sparkContext: SparkContext): Map[String, SQLMetric] = + throw new UnsupportedOperationException( + "UnionExecTransformer metrics update is not supported in CH backend") + + override def genUnionTransformerMetricsUpdater(metrics: Map[String, SQLMetric]): MetricsUpdater = + throw new UnsupportedOperationException( + "UnionExecTransformer metrics update is not supported in CH backend") + def genWriteFilesTransformerMetrics(sparkContext: SparkContext): Map[String, SQLMetric] = Map( "physicalWrittenBytes" -> SQLMetrics.createMetric(sparkContext, "number of written bytes"), diff --git a/backends-velox/src/main/java/org/apache/gluten/utils/GlutenURLDecoder.java b/backends-velox/src/main/java/org/apache/gluten/utils/GlutenURLDecoder.java index 9228a2f860ae..856ddf159730 100644 --- a/backends-velox/src/main/java/org/apache/gluten/utils/GlutenURLDecoder.java +++ b/backends-velox/src/main/java/org/apache/gluten/utils/GlutenURLDecoder.java @@ -31,7 +31,7 @@ public class GlutenURLDecoder { *

Note: The World Wide Web Consortium * Recommendation states that UTF-8 should be used. Not doing so may introduce - * incompatibilites. + * incompatibilities. * * @param s the String to decode * @param enc The name of a supported character diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxListenerApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxListenerApi.scala index d29d3029709e..3a82abe61833 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxListenerApi.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxListenerApi.scala @@ -31,6 +31,7 @@ import org.apache.spark.{HdfsConfGenerator, SparkConf, SparkContext} import org.apache.spark.api.plugin.PluginContext import org.apache.spark.internal.Logging import org.apache.spark.network.util.ByteUnit +import org.apache.spark.sql.execution.ColumnarCachedBatchSerializer import org.apache.spark.sql.execution.datasources.GlutenWriterColumnarRules import org.apache.spark.sql.execution.datasources.velox.{VeloxParquetWriterInjects, VeloxRowSplitter} import org.apache.spark.sql.expression.UDFResolver @@ -75,7 +76,7 @@ class VeloxListenerApi extends ListenerApi with Logging { if (conf.getBoolean(GlutenConfig.COLUMNAR_TABLE_CACHE_ENABLED.key, defaultValue = false)) { conf.set( StaticSQLConf.SPARK_CACHE_SERIALIZER.key, - "org.apache.spark.sql.execution.ColumnarCachedBatchSerializer") + classOf[ColumnarCachedBatchSerializer].getName) } // Static initializers for driver. diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxMetricsApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxMetricsApi.scala index e70e1d13bdfe..934b680382ea 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxMetricsApi.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxMetricsApi.scala @@ -582,4 +582,15 @@ class VeloxMetricsApi extends MetricsApi with Logging { override def genSampleTransformerMetricsUpdater(metrics: Map[String, SQLMetric]): MetricsUpdater = new SampleMetricsUpdater(metrics) + + override def genUnionTransformerMetrics(sparkContext: SparkContext): Map[String, SQLMetric] = Map( + "numInputRows" -> SQLMetrics.createMetric(sparkContext, "number of input rows"), + "inputVectors" -> SQLMetrics.createMetric(sparkContext, "number of input vectors"), + "inputBytes" -> SQLMetrics.createSizeMetric(sparkContext, "number of input bytes"), + "wallNanos" -> SQLMetrics.createNanoTimingMetric(sparkContext, "time of union"), + "cpuCount" -> SQLMetrics.createMetric(sparkContext, "cpu wall time count") + ) + + override def genUnionTransformerMetricsUpdater(metrics: Map[String, SQLMetric]): MetricsUpdater = + new UnionMetricsUpdater(metrics) } diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala index 7841e6cd94b1..7337be573710 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala @@ -92,6 +92,7 @@ object VeloxRuleApi { c => HeuristicTransform.Single(validatorBuilder(c.glutenConf), rewrites, offloads)) // Legacy: Post-transform rules. + injector.injectPostTransform(_ => UnionTransformerRule()) injector.injectPostTransform(c => PartialProjectRule.apply(c.session)) injector.injectPostTransform(_ => RemoveNativeWriteFilesSortAndProject()) injector.injectPostTransform(c => RewriteTransformer.apply(c.session)) @@ -178,6 +179,7 @@ object VeloxRuleApi { // Gluten RAS: Post rules. injector.injectPostTransform(_ => RemoveTransitions) + injector.injectPostTransform(_ => UnionTransformerRule()) injector.injectPostTransform(c => PartialProjectRule.apply(c.session)) injector.injectPostTransform(_ => RemoveNativeWriteFilesSortAndProject()) injector.injectPostTransform(c => RewriteTransformer.apply(c.session)) diff --git a/backends-velox/src/main/scala/org/apache/gluten/metrics/MetricsUtil.scala b/backends-velox/src/main/scala/org/apache/gluten/metrics/MetricsUtil.scala index cd50d0b8e20c..b8ef1620f905 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/metrics/MetricsUtil.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/metrics/MetricsUtil.scala @@ -58,7 +58,8 @@ object MetricsUtil extends Logging { assert(t.children.size == 1, "MetricsUpdater.None can only be used on unary operator") treeifyMetricsUpdaters(t.children.head) case t: TransformSupport => - MetricsUpdaterTree(t.metricsUpdater(), t.children.map(treeifyMetricsUpdaters)) + // Reversed children order to match the traversal code. + MetricsUpdaterTree(t.metricsUpdater(), t.children.reverse.map(treeifyMetricsUpdaters)) case _ => MetricsUpdaterTree(MetricsUpdater.Terminate, Seq()) } @@ -233,6 +234,12 @@ object MetricsUtil extends Logging { operatorMetrics, metrics.getSingleMetrics, joinParamsMap.get(operatorIdx)) + case u: UnionMetricsUpdater => + // JoinRel outputs two suites of metrics respectively for hash build and hash probe. + // Therefore, fetch one more suite of metrics here. + operatorMetrics.add(metrics.getOperatorMetrics(curMetricsIdx)) + curMetricsIdx -= 1 + u.updateUnionMetrics(operatorMetrics) case hau: HashAggregateMetricsUpdater => hau.updateAggregationMetrics(operatorMetrics, aggParamsMap.get(operatorIdx)) case lu: LimitMetricsUpdater => diff --git a/backends-velox/src/main/scala/org/apache/gluten/metrics/UnionMetricsUpdater.scala b/backends-velox/src/main/scala/org/apache/gluten/metrics/UnionMetricsUpdater.scala new file mode 100644 index 000000000000..9e91cf368c0a --- /dev/null +++ b/backends-velox/src/main/scala/org/apache/gluten/metrics/UnionMetricsUpdater.scala @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.gluten.metrics + +import org.apache.spark.sql.execution.metric.SQLMetric + +class UnionMetricsUpdater(val metrics: Map[String, SQLMetric]) extends MetricsUpdater { + override def updateNativeMetrics(opMetrics: IOperatorMetrics): Unit = { + throw new UnsupportedOperationException() + } + + def updateUnionMetrics(unionMetrics: java.util.ArrayList[OperatorMetrics]): Unit = { + // Union was interpreted to LocalExchange + LocalPartition. Use metrics from LocalExchange. + val localExchangeMetrics = unionMetrics.get(0) + metrics("numInputRows") += localExchangeMetrics.inputRows + metrics("inputVectors") += localExchangeMetrics.inputVectors + metrics("inputBytes") += localExchangeMetrics.inputBytes + metrics("cpuCount") += localExchangeMetrics.cpuCount + metrics("wallNanos") += localExchangeMetrics.wallNanos + } +} diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/MiscOperatorSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/MiscOperatorSuite.scala index 5cb2b652604d..8063a5d12207 100644 --- a/backends-velox/src/test/scala/org/apache/gluten/execution/MiscOperatorSuite.scala +++ b/backends-velox/src/test/scala/org/apache/gluten/execution/MiscOperatorSuite.scala @@ -537,11 +537,37 @@ class MiscOperatorSuite extends VeloxWholeStageTransformerSuite with AdaptiveSpa |""".stripMargin) { df => { - getExecutedPlan(df).exists(plan => plan.find(_.isInstanceOf[ColumnarUnionExec]).isDefined) + assert( + getExecutedPlan(df).exists( + plan => plan.find(_.isInstanceOf[ColumnarUnionExec]).isDefined)) } } } + test("union_all two tables with known partitioning") { + withSQLConf(GlutenConfig.NATIVE_UNION_ENABLED.key -> "true") { + compareDfResultsAgainstVanillaSpark( + () => { + val df1 = spark.sql("select l_orderkey as orderkey from lineitem") + val df2 = spark.sql("select o_orderkey as orderkey from orders") + df1.repartition(5).union(df2.repartition(5)) + }, + compareResult = true, + checkGlutenOperatorMatch[UnionExecTransformer] + ) + + compareDfResultsAgainstVanillaSpark( + () => { + val df1 = spark.sql("select l_orderkey as orderkey from lineitem") + val df2 = spark.sql("select o_orderkey as orderkey from orders") + df1.repartition(5).union(df2.repartition(6)) + }, + compareResult = true, + checkGlutenOperatorMatch[ColumnarUnionExec] + ) + } + } + test("union_all three tables") { runQueryAndCompare(""" |select count(orderkey) from ( diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxOrcDataTypeValidationSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxOrcDataTypeValidationSuite.scala index 24e04f2dfce3..6ac59ba4fa6b 100644 --- a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxOrcDataTypeValidationSuite.scala +++ b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxOrcDataTypeValidationSuite.scala @@ -255,7 +255,10 @@ class VeloxOrcDataTypeValidationSuite extends VeloxWholeStageTransformerSuite { |""".stripMargin) { df => { - assert(getExecutedPlan(df).exists(plan => plan.isInstanceOf[ColumnarUnionExec])) + assert( + getExecutedPlan(df).exists( + plan => + plan.isInstanceOf[ColumnarUnionExec] || plan.isInstanceOf[UnionExecTransformer])) } } diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxParquetDataTypeValidationSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxParquetDataTypeValidationSuite.scala index 57ca448fec79..cb5614f39669 100644 --- a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxParquetDataTypeValidationSuite.scala +++ b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxParquetDataTypeValidationSuite.scala @@ -254,7 +254,10 @@ class VeloxParquetDataTypeValidationSuite extends VeloxWholeStageTransformerSuit |""".stripMargin) { df => { - assert(getExecutedPlan(df).exists(plan => plan.isInstanceOf[ColumnarUnionExec])) + assert( + getExecutedPlan(df).exists( + plan => + plan.isInstanceOf[ColumnarUnionExec] || plan.isInstanceOf[UnionExecTransformer])) } } diff --git a/cpp/velox/compute/WholeStageResultIterator.cc b/cpp/velox/compute/WholeStageResultIterator.cc index b6ecbd959f09..411c6c563646 100644 --- a/cpp/velox/compute/WholeStageResultIterator.cc +++ b/cpp/velox/compute/WholeStageResultIterator.cc @@ -91,7 +91,7 @@ WholeStageResultIterator::WholeStageResultIterator( std::move(queryCtx), velox::exec::Task::ExecutionMode::kSerial); if (!task_->supportSerialExecutionMode()) { - throw std::runtime_error("Task doesn't support single thread execution: " + planNode->toString()); + throw std::runtime_error("Task doesn't support single threaded execution: " + planNode->toString()); } auto fileSystem = velox::filesystems::getFileSystem(spillDir, nullptr); GLUTEN_CHECK(fileSystem != nullptr, "File System for spilling is null!"); @@ -248,15 +248,47 @@ void WholeStageResultIterator::getOrderedNodeIds( const std::shared_ptr& planNode, std::vector& nodeIds) { bool isProjectNode = (std::dynamic_pointer_cast(planNode) != nullptr); + bool isLocalExchangeNode = (std::dynamic_pointer_cast(planNode) != nullptr); + bool isUnionNode = isLocalExchangeNode && + std::dynamic_pointer_cast(planNode)->type() == + velox::core::LocalPartitionNode::Type::kGather; const auto& sourceNodes = planNode->sources(); - for (const auto& sourceNode : sourceNodes) { + if (isProjectNode) { + GLUTEN_CHECK(sourceNodes.size() == 1, "Illegal state"); + const auto sourceNode = sourceNodes.at(0); // Filter over Project are mapped into FilterProject operator in Velox. // Metrics are all applied on Project node, and the metrics for Filter node // do not exist. - if (isProjectNode && std::dynamic_pointer_cast(sourceNode)) { + if (std::dynamic_pointer_cast(sourceNode)) { omittedNodeIds_.insert(sourceNode->id()); } getOrderedNodeIds(sourceNode, nodeIds); + nodeIds.emplace_back(planNode->id()); + return; + } + + if (isUnionNode) { + // FIXME: The whole metrics system in gluten-substrait is magic. Passing metrics trees through JNI with a trivial + // array is possible but requires for a solid design. Apparently we haven't had it. All the code requires complete + // rework. + // Union was interpreted as LocalPartition + LocalExchange + 2 fake projects as children in Velox. So we only fetch + // metrics from the root node. + std::vector> unionChildren{}; + for (const auto& source : planNode->sources()) { + const auto projectedChild = std::dynamic_pointer_cast(source); + GLUTEN_CHECK(projectedChild != nullptr, "Illegal state"); + const auto projectSources = projectedChild->sources(); + GLUTEN_CHECK(projectSources.size() == 1, "Illegal state"); + const auto projectSource = projectSources.at(0); + getOrderedNodeIds(projectSource, nodeIds); + } + nodeIds.emplace_back(planNode->id()); + return; + } + + for (const auto& sourceNode : sourceNodes) { + // Post-order traversal. + getOrderedNodeIds(sourceNode, nodeIds); } nodeIds.emplace_back(planNode->id()); } @@ -350,9 +382,9 @@ void WholeStageResultIterator::collectMetrics() { continue; } - const auto& status = planStats.at(nodeId); - // Add each operator status into metrics. - for (const auto& entry : status.operatorStats) { + const auto& stats = planStats.at(nodeId); + // Add each operator stats into metrics. + for (const auto& entry : stats.operatorStats) { const auto& second = entry.second; metrics_->get(Metrics::kInputRows)[metricIndex] = second->inputRows; metrics_->get(Metrics::kInputVectors)[metricIndex] = second->inputVectors; diff --git a/cpp/velox/substrait/SubstraitToVeloxPlan.cc b/cpp/velox/substrait/SubstraitToVeloxPlan.cc index 1efa7338796d..3ceccca4a3de 100644 --- a/cpp/velox/substrait/SubstraitToVeloxPlan.cc +++ b/cpp/velox/substrait/SubstraitToVeloxPlan.cc @@ -1043,6 +1043,50 @@ core::PlanNodePtr SubstraitToVeloxPlanConverter::toVeloxPlan( childNode); } +core::PlanNodePtr SubstraitToVeloxPlanConverter::toVeloxPlan(const ::substrait::SetRel& setRel) { + switch (setRel.op()) { + case ::substrait::SetRel_SetOp::SetRel_SetOp_SET_OP_UNION_ALL: { + std::vector children; + for (int32_t i = 0; i < setRel.inputs_size(); ++i) { + const auto& input = setRel.inputs(i); + children.push_back(toVeloxPlan(input)); + } + GLUTEN_CHECK(!children.empty(), "At least one source is required for Velox LocalPartition"); + + // Velox doesn't allow different field names in schemas of LocalPartitionNode's children. + // Add project nodes to unify the schemas. + const RowTypePtr outRowType = asRowType(children[0]->outputType()); + std::vector outNames; + for (int32_t colIdx = 0; colIdx < outRowType->size(); ++colIdx) { + const auto name = outRowType->childAt(colIdx)->name(); + outNames.push_back(name); + } + + std::vector projectedChildren; + for (int32_t i = 0; i < children.size(); ++i) { + const auto& child = children[i]; + const RowTypePtr& childRowType = child->outputType(); + std::vector expressions; + for (int32_t colIdx = 0; colIdx < outNames.size(); ++colIdx) { + const auto fa = + std::make_shared(childRowType->childAt(colIdx), childRowType->nameOf(colIdx)); + const auto cast = std::make_shared(outRowType->childAt(colIdx), fa, false); + expressions.push_back(cast); + } + auto project = std::make_shared(nextPlanNodeId(), outNames, expressions, child); + projectedChildren.push_back(project); + } + return std::make_shared( + nextPlanNodeId(), + core::LocalPartitionNode::Type::kGather, + std::make_shared(), + projectedChildren); + } + default: + throw GlutenException("Unsupported SetRel op: " + std::to_string(setRel.op())); + } +} + core::PlanNodePtr SubstraitToVeloxPlanConverter::toVeloxPlan(const ::substrait::SortRel& sortRel) { auto childNode = convertSingleInput<::substrait::SortRel>(sortRel); auto [sortingKeys, sortingOrders] = processSortField(sortRel.sorts(), childNode->outputType()); @@ -1298,6 +1342,8 @@ core::PlanNodePtr SubstraitToVeloxPlanConverter::toVeloxPlan(const ::substrait:: return toVeloxPlan(rel.write()); } else if (rel.has_windowgrouplimit()) { return toVeloxPlan(rel.windowgrouplimit()); + } else if (rel.has_set()) { + return toVeloxPlan(rel.set()); } else { VELOX_NYI("Substrait conversion not supported for Rel."); } diff --git a/cpp/velox/substrait/SubstraitToVeloxPlan.h b/cpp/velox/substrait/SubstraitToVeloxPlan.h index 51e50ce34767..6121923df787 100644 --- a/cpp/velox/substrait/SubstraitToVeloxPlan.h +++ b/cpp/velox/substrait/SubstraitToVeloxPlan.h @@ -84,6 +84,9 @@ class SubstraitToVeloxPlanConverter { /// Used to convert Substrait WindowGroupLimitRel into Velox PlanNode. core::PlanNodePtr toVeloxPlan(const ::substrait::WindowGroupLimitRel& windowGroupLimitRel); + /// Used to convert Substrait SetRel into Velox PlanNode. + core::PlanNodePtr toVeloxPlan(const ::substrait::SetRel& setRel); + /// Used to convert Substrait JoinRel into Velox PlanNode. core::PlanNodePtr toVeloxPlan(const ::substrait::JoinRel& joinRel); diff --git a/cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc b/cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc index 3b74caf8ba5a..9325fed3217c 100644 --- a/cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc +++ b/cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc @@ -22,7 +22,6 @@ #include "TypeUtils.h" #include "udf/UdfLoader.h" #include "utils/Common.h" -#include "velox/core/ExpressionEvaluator.h" #include "velox/exec/Aggregate.h" #include "velox/expression/Expr.h" #include "velox/expression/SignatureBinder.h" @@ -30,7 +29,7 @@ namespace gluten { namespace { -static const char* extractFileName(const char* file) { +const char* extractFileName(const char* file) { return strrchr(file, '/') ? strrchr(file, '/') + 1 : file; } @@ -53,13 +52,13 @@ static const char* extractFileName(const char* file) { __FUNCTION__, \ reason)) -static const std::unordered_set kRegexFunctions = { +const std::unordered_set kRegexFunctions = { "regexp_extract", "regexp_extract_all", "regexp_replace", "rlike"}; -static const std::unordered_set kBlackList = { +const std::unordered_set kBlackList = { "split_part", "factorial", "concat_ws", @@ -70,32 +69,59 @@ static const std::unordered_set kBlackList = { "approx_percentile", "get_array_struct_fields", "map_from_arrays"}; - } // namespace -bool SubstraitToVeloxPlanValidator::validateInputTypes( +bool SubstraitToVeloxPlanValidator::parseVeloxType( const ::substrait::extensions::AdvancedExtension& extension, - std::vector& types) { + TypePtr& out) { + ::substrait::Type substraitType; // The input type is wrapped in enhancement. if (!extension.has_enhancement()) { LOG_VALIDATION_MSG("Input type is not wrapped in enhancement."); return false; } const auto& enhancement = extension.enhancement(); - ::substrait::Type inputType; - if (!enhancement.UnpackTo(&inputType)) { + if (!enhancement.UnpackTo(&substraitType)) { LOG_VALIDATION_MSG("Enhancement can't be unpacked to inputType."); return false; } - if (!inputType.has_struct_()) { - LOG_VALIDATION_MSG("Input type has no struct."); + + out = SubstraitParser::parseType(substraitType); + return true; +} + +bool SubstraitToVeloxPlanValidator::flattenVeloxType1(const TypePtr& type, std::vector& out) { + if (type->kind() != TypeKind::ROW) { + LOG_VALIDATION_MSG("Type is not a RowType."); + return false; + } + auto rowType = std::dynamic_pointer_cast(type); + if (!rowType) { + LOG_VALIDATION_MSG("Failed to cast to RowType."); return false; } + for (const auto& field : rowType->children()) { + out.emplace_back(field); + } + return true; +} - // Get the input types. - const auto& sTypes = inputType.struct_().types(); - for (const auto& sType : sTypes) { - types.emplace_back(SubstraitParser::parseType(sType)); +bool SubstraitToVeloxPlanValidator::flattenVeloxType2(const TypePtr& type, std::vector>& out) { + if (type->kind() != TypeKind::ROW) { + LOG_VALIDATION_MSG("Type is not a RowType."); + return false; + } + auto rowType = std::dynamic_pointer_cast(type); + if (!rowType) { + LOG_VALIDATION_MSG("Failed to cast to RowType."); + return false; + } + for (const auto& field : rowType->children()) { + std::vector inner; + if (!flattenVeloxType1(field, inner)) { + return false; + } + out.emplace_back(inner); } return true; } @@ -341,10 +367,11 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::WriteRel& writeR } // Validate input data type. + TypePtr inputRowType; std::vector types; if (writeRel.has_named_table()) { const auto& extension = writeRel.named_table().advanced_extension(); - if (!validateInputTypes(extension, types)) { + if (!parseVeloxType(extension, inputRowType) || !flattenVeloxType1(inputRowType, types)) { LOG_VALIDATION_MSG("Validation failed for input type validation in WriteRel."); return false; } @@ -380,12 +407,12 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::WriteRel& writeR } bool SubstraitToVeloxPlanValidator::validate(const ::substrait::FetchRel& fetchRel) { - RowTypePtr rowType = nullptr; // Get and validate the input types from extension. if (fetchRel.has_advanced_extension()) { const auto& extension = fetchRel.advanced_extension(); + TypePtr inputRowType; std::vector types; - if (!validateInputTypes(extension, types)) { + if (!parseVeloxType(extension, inputRowType) || !flattenVeloxType1(inputRowType, types)) { LOG_VALIDATION_MSG("Unsupported input types in FetchRel."); return false; } @@ -396,7 +423,6 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::FetchRel& fetchR for (auto colIdx = 0; colIdx < types.size(); colIdx++) { names.emplace_back(SubstraitParser::makeNodeName(inputPlanNodeId, colIdx)); } - rowType = std::make_shared(std::move(names), std::move(types)); } if (fetchRel.offset() < 0 || fetchRel.count() < 0) { @@ -412,8 +438,9 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::TopNRel& topNRel // Get and validate the input types from extension. if (topNRel.has_advanced_extension()) { const auto& extension = topNRel.advanced_extension(); + TypePtr inputRowType; std::vector types; - if (!validateInputTypes(extension, types)) { + if (!parseVeloxType(extension, inputRowType) || !flattenVeloxType1(inputRowType, types)) { LOG_VALIDATION_MSG("Unsupported input types in TopNRel."); return false; } @@ -457,8 +484,9 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::GenerateRel& gen return false; } const auto& extension = generateRel.advanced_extension(); + TypePtr inputRowType; std::vector types; - if (!validateInputTypes(extension, types)) { + if (!parseVeloxType(extension, inputRowType) || !flattenVeloxType1(inputRowType, types)) { LOG_VALIDATION_MSG("Validation failed for input types in GenerateRel."); return false; } @@ -487,8 +515,9 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::ExpandRel& expan // Get and validate the input types from extension. if (expandRel.has_advanced_extension()) { const auto& extension = expandRel.advanced_extension(); + TypePtr inputRowType; std::vector types; - if (!validateInputTypes(extension, types)) { + if (!parseVeloxType(extension, inputRowType) || !flattenVeloxType1(inputRowType, types)) { LOG_VALIDATION_MSG("Unsupported input types in ExpandRel."); return false; } @@ -571,8 +600,9 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::WindowRel& windo return false; } const auto& extension = windowRel.advanced_extension(); + TypePtr inputRowType; std::vector types; - if (!validateInputTypes(extension, types)) { + if (!parseVeloxType(extension, inputRowType) || !flattenVeloxType1(inputRowType, types)) { LOG_VALIDATION_MSG("Validation failed for input types in WindowRel."); return false; } @@ -680,7 +710,7 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::WindowRel& windo LOG_VALIDATION_MSG("in windowRel, the sorting key in Sort Operator only support field."); return false; } - exec::ExprSet exprSet({std::move(expression)}, execCtx_); + exec::ExprSet exprSet1({std::move(expression)}, execCtx_); } } @@ -699,8 +729,9 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::WindowGroupLimit return false; } const auto& extension = windowGroupLimitRel.advanced_extension(); + TypePtr inputRowType; std::vector types; - if (!validateInputTypes(extension, types)) { + if (!parseVeloxType(extension, inputRowType) || !flattenVeloxType1(inputRowType, types)) { LOG_VALIDATION_MSG("Validation failed for input types in WindowGroupLimitRel."); return false; } @@ -750,13 +781,61 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::WindowGroupLimit LOG_VALIDATION_MSG("in windowGroupLimitRel, the sorting key in Sort Operator only support field."); return false; } - exec::ExprSet exprSet({std::move(expression)}, execCtx_); + exec::ExprSet exprSet1({std::move(expression)}, execCtx_); } } return true; } +bool SubstraitToVeloxPlanValidator::validate(const ::substrait::SetRel& setRel) { + switch (setRel.op()) { + case ::substrait::SetRel_SetOp::SetRel_SetOp_SET_OP_UNION_ALL: { + for (int32_t i = 0; i < setRel.inputs_size(); ++i) { + const auto& input = setRel.inputs(i); + if (!validate(input)) { + LOG_VALIDATION_MSG("ProjectRel input"); + return false; + } + } + if (!setRel.has_advanced_extension()) { + LOG_VALIDATION_MSG("Input types are expected in SetRel."); + return false; + } + const auto& extension = setRel.advanced_extension(); + TypePtr inputRowType; + std::vector> childrenTypes; + if (!parseVeloxType(extension, inputRowType) || !flattenVeloxType2(inputRowType, childrenTypes)) { + LOG_VALIDATION_MSG("Validation failed for input types in SetRel."); + return false; + } + std::vector childrenRowTypes; + for (auto i = 0; i < childrenTypes.size(); ++i) { + auto& types = childrenTypes.at(i); + std::vector names; + names.reserve(types.size()); + for (auto colIdx = 0; colIdx < types.size(); colIdx++) { + names.emplace_back(SubstraitParser::makeNodeName(i, colIdx)); + } + childrenRowTypes.push_back(std::make_shared(std::move(names), std::move(types))); + } + + for (auto i = 1; i < childrenRowTypes.size(); ++i) { + if (!(childrenRowTypes[i]->equivalent(*childrenRowTypes[0]))) { + LOG_VALIDATION_MSG( + "All sources of the Set operation must have the same output type: " + childrenRowTypes[i]->toString() + + " vs. " + childrenRowTypes[0]->toString()); + return false; + } + } + return true; + } + default: + LOG_VALIDATION_MSG("Unsupported SetRel op: " + std::to_string(setRel.op())); + return false; + } +} + bool SubstraitToVeloxPlanValidator::validate(const ::substrait::SortRel& sortRel) { if (sortRel.has_input() && !validate(sortRel.input())) { return false; @@ -769,8 +848,9 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::SortRel& sortRel } const auto& extension = sortRel.advanced_extension(); + TypePtr inputRowType; std::vector types; - if (!validateInputTypes(extension, types)) { + if (!parseVeloxType(extension, inputRowType) || !flattenVeloxType1(inputRowType, types)) { LOG_VALIDATION_MSG("Validation failed for input types in SortRel."); return false; } @@ -822,8 +902,9 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::ProjectRel& proj return false; } const auto& extension = projectRel.advanced_extension(); + TypePtr inputRowType; std::vector types; - if (!validateInputTypes(extension, types)) { + if (!parseVeloxType(extension, inputRowType) || !flattenVeloxType1(inputRowType, types)) { LOG_VALIDATION_MSG("Validation failed for input types in ProjectRel."); return false; } @@ -865,8 +946,9 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::FilterRel& filte return false; } const auto& extension = filterRel.advanced_extension(); + TypePtr inputRowType; std::vector types; - if (!validateInputTypes(extension, types)) { + if (!parseVeloxType(extension, inputRowType) || !flattenVeloxType1(inputRowType, types)) { LOG_VALIDATION_MSG("Validation failed for input types in FilterRel."); return false; } @@ -938,8 +1020,9 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::JoinRel& joinRel } const auto& extension = joinRel.advanced_extension(); + TypePtr inputRowType; std::vector types; - if (!validateInputTypes(extension, types)) { + if (!parseVeloxType(extension, inputRowType) || !flattenVeloxType1(inputRowType, types)) { LOG_VALIDATION_MSG("Validation failed for input types in JoinRel."); return false; } @@ -991,8 +1074,9 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::CrossRel& crossR } const auto& extension = crossRel.advanced_extension(); + TypePtr inputRowType; std::vector types; - if (!validateInputTypes(extension, types)) { + if (!parseVeloxType(extension, inputRowType) || !flattenVeloxType1(inputRowType, types)) { logValidateMsg("Native validation failed due to: Validation failed for input types in CrossRel"); return false; } @@ -1070,11 +1154,13 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::AggregateRel& ag // Validate input types. if (aggRel.has_advanced_extension()) { + TypePtr inputRowType; std::vector types; const auto& extension = aggRel.advanced_extension(); // Aggregate always has advanced extension for streaming aggregate optimization, // but only some of them have enhancement for validation. - if (extension.has_enhancement() && !validateInputTypes(extension, types)) { + if (extension.has_enhancement() && + (!parseVeloxType(extension, inputRowType) || !flattenVeloxType1(inputRowType, types))) { LOG_VALIDATION_MSG("Validation failed for input types in AggregateRel."); return false; } @@ -1266,7 +1352,10 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::Rel& rel) { return validate(rel.write()); } else if (rel.has_windowgrouplimit()) { return validate(rel.windowgrouplimit()); + } else if (rel.has_set()) { + return validate(rel.set()); } else { + LOG_VALIDATION_MSG("Unsupported relation type: " + rel.GetTypeName()); return false; } } diff --git a/cpp/velox/substrait/SubstraitToVeloxPlanValidator.h b/cpp/velox/substrait/SubstraitToVeloxPlanValidator.h index 1fe174928fd9..0c8d882ca031 100644 --- a/cpp/velox/substrait/SubstraitToVeloxPlanValidator.h +++ b/cpp/velox/substrait/SubstraitToVeloxPlanValidator.h @@ -61,6 +61,9 @@ class SubstraitToVeloxPlanValidator { /// Used to validate whether the computing of this WindowGroupLimit is supported. bool validate(const ::substrait::WindowGroupLimitRel& windowGroupLimitRel); + /// Used to validate whether the computing of this Set is supported. + bool validate(const ::substrait::SetRel& setRel); + /// Used to validate whether the computing of this Aggregation is supported. bool validate(const ::substrait::AggregateRel& aggRel); @@ -103,9 +106,17 @@ class SubstraitToVeloxPlanValidator { std::vector validateLog_; - /// Used to get types from advanced extension and validate them. - bool validateInputTypes(const ::substrait::extensions::AdvancedExtension& extension, std::vector& types); + /// Used to get types from advanced extension and validate them, then convert to a Velox type that has arbitrary + /// levels of nesting. + bool parseVeloxType(const ::substrait::extensions::AdvancedExtension& extension, TypePtr& out); + + /// Flattens a Velox type with single level of nesting into a std::vector of child types. + bool flattenVeloxType1(const TypePtr& type, std::vector& out); + + /// Flattens a Velox type with two level of nesting into a dual-nested std::vector of child types. + bool flattenVeloxType2(const TypePtr& type, std::vector>& out); + /// Validate aggregate rel. bool validateAggRelFunctionType(const ::substrait::AggregateRel& substraitAgg); /// Validate the round scalar function. diff --git a/gluten-substrait/src/main/java/org/apache/gluten/substrait/rel/RelBuilder.java b/gluten-substrait/src/main/java/org/apache/gluten/substrait/rel/RelBuilder.java index def1dca0a028..7d1931180847 100644 --- a/gluten-substrait/src/main/java/org/apache/gluten/substrait/rel/RelBuilder.java +++ b/gluten-substrait/src/main/java/org/apache/gluten/substrait/rel/RelBuilder.java @@ -27,6 +27,7 @@ import io.substrait.proto.CrossRel; import io.substrait.proto.JoinRel; +import io.substrait.proto.SetRel; import io.substrait.proto.SortField; import org.apache.spark.sql.catalyst.expressions.Attribute; @@ -317,4 +318,20 @@ public static RelNode makeGenerateRel( context.registerRelToOperator(operatorId); return new GenerateRelNode(input, generator, childOutput, extensionNode, outer); } + + public static RelNode makeSetRel( + List inputs, SetRel.SetOp setOp, SubstraitContext context, Long operatorId) { + context.registerRelToOperator(operatorId); + return new SetRelNode(inputs, setOp); + } + + public static RelNode makeSetRel( + List inputs, + SetRel.SetOp setOp, + AdvancedExtensionNode extensionNode, + SubstraitContext context, + Long operatorId) { + context.registerRelToOperator(operatorId); + return new SetRelNode(inputs, setOp, extensionNode); + } } diff --git a/gluten-substrait/src/main/java/org/apache/gluten/substrait/rel/SetRelNode.java b/gluten-substrait/src/main/java/org/apache/gluten/substrait/rel/SetRelNode.java new file mode 100644 index 000000000000..ddcfb1701dd1 --- /dev/null +++ b/gluten-substrait/src/main/java/org/apache/gluten/substrait/rel/SetRelNode.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.gluten.substrait.rel; + +import org.apache.gluten.substrait.extensions.AdvancedExtensionNode; + +import io.substrait.proto.Rel; +import io.substrait.proto.RelCommon; +import io.substrait.proto.SetRel; + +import java.io.Serializable; +import java.util.List; + +public class SetRelNode implements RelNode, Serializable { + private final List inputs; + private final SetRel.SetOp setOp; + private final AdvancedExtensionNode extensionNode; + + public SetRelNode(List inputs, SetRel.SetOp setOp, AdvancedExtensionNode extensionNode) { + this.inputs = inputs; + this.setOp = setOp; + this.extensionNode = extensionNode; + } + + public SetRelNode(List inputs, SetRel.SetOp setOp) { + this(inputs, setOp, null); + } + + @Override + public Rel toProtobuf() { + final RelCommon.Builder relCommonBuilder = RelCommon.newBuilder(); + relCommonBuilder.setDirect(RelCommon.Direct.newBuilder()); + final SetRel.Builder setBuilder = SetRel.newBuilder(); + setBuilder.setCommon(relCommonBuilder.build()); + if (inputs != null) { + for (RelNode input : inputs) { + setBuilder.addInputs(input.toProtobuf()); + } + } + setBuilder.setOp(setOp); + if (extensionNode != null) { + setBuilder.setAdvancedExtension(extensionNode.toProtobuf()); + } + final Rel.Builder builder = Rel.newBuilder(); + builder.setSet(setBuilder.build()); + return builder.build(); + } +} diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/MetricsApi.scala b/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/MetricsApi.scala index c67d4b5f8876..453cfab4e487 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/MetricsApi.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/MetricsApi.scala @@ -126,6 +126,10 @@ trait MetricsApi extends Serializable { def genSampleTransformerMetricsUpdater(metrics: Map[String, SQLMetric]): MetricsUpdater + def genUnionTransformerMetrics(sparkContext: SparkContext): Map[String, SQLMetric] + + def genUnionTransformerMetricsUpdater(metrics: Map[String, SQLMetric]): MetricsUpdater + def genColumnarInMemoryTableMetrics(sparkContext: SparkContext): Map[String, SQLMetric] = Map("numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) } diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/execution/BasicPhysicalOperatorTransformer.scala b/gluten-substrait/src/main/scala/org/apache/gluten/execution/BasicPhysicalOperatorTransformer.scala index f9755605cab2..ac8e610956dc 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/execution/BasicPhysicalOperatorTransformer.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/execution/BasicPhysicalOperatorTransformer.scala @@ -261,10 +261,11 @@ abstract class ProjectExecTransformerBase(val list: Seq[NamedExpression], val in } } -// An alternatives for UnionExec. +// An alternative for UnionExec. case class ColumnarUnionExec(children: Seq[SparkPlan]) extends ValidatablePlan { children.foreach { case w: WholeStageTransformer => + // FIXME: Avoid such practice for plan immutability. w.setOutputSchemaForPlan(output) case _ => } diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/execution/UnionExecTransformer.scala b/gluten-substrait/src/main/scala/org/apache/gluten/execution/UnionExecTransformer.scala new file mode 100644 index 000000000000..d27558746a40 --- /dev/null +++ b/gluten-substrait/src/main/scala/org/apache/gluten/execution/UnionExecTransformer.scala @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.gluten.execution + +import org.apache.gluten.backendsapi.BackendsApiManager +import org.apache.gluten.expression.ConverterUtils +import org.apache.gluten.extension.ValidationResult +import org.apache.gluten.metrics.MetricsUpdater +import org.apache.gluten.substrait.`type`.TypeBuilder +import org.apache.gluten.substrait.SubstraitContext +import org.apache.gluten.substrait.extensions.ExtensionBuilder +import org.apache.gluten.substrait.rel.{RelBuilder, RelNode} + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.execution.{SparkPlan, UnionExec} +import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.vectorized.ColumnarBatch + +import io.substrait.proto.SetRel.SetOp + +import scala.collection.JavaConverters._ + +/** Transformer for UnionExec. Note: Spark's UnionExec represents a SQL UNION ALL. */ +case class UnionExecTransformer(children: Seq[SparkPlan]) extends TransformSupport { + private val union = UnionExec(children) + + // Note: "metrics" is made transient to avoid sending driver-side metrics to tasks. + @transient override lazy val metrics: Map[String, SQLMetric] = + BackendsApiManager.getMetricsApiInstance.genUnionTransformerMetrics(sparkContext) + + override def output: Seq[Attribute] = union.output + + override def columnarInputRDDs: Seq[RDD[ColumnarBatch]] = children.flatMap(getColumnarInputRDDs) + + override def metricsUpdater(): MetricsUpdater = + BackendsApiManager.getMetricsApiInstance.genUnionTransformerMetricsUpdater(metrics) + + override protected def withNewChildrenInternal(newChildren: IndexedSeq[SparkPlan]): SparkPlan = + copy(children = newChildren) + + override protected def doValidateInternal(): ValidationResult = { + val context = new SubstraitContext + val operatorId = context.nextOperatorId(this.nodeName) + val relNode = getRelNode(context, operatorId, children.map(_.output), null, true) + doNativeValidation(context, relNode) + } + + override protected def doTransform(context: SubstraitContext): TransformContext = { + val childrenCtx = children.map(_.asInstanceOf[TransformSupport].transform(context)) + val operatorId = context.nextOperatorId(this.nodeName) + val relNode = + getRelNode(context, operatorId, children.map(_.output), childrenCtx.map(_.root), false) + TransformContext(output, relNode) + } + + private def getRelNode( + context: SubstraitContext, + operatorId: Long, + inputAttributes: Seq[Seq[Attribute]], + inputs: Seq[RelNode], + validation: Boolean): RelNode = { + if (validation) { + // Use the second level of nesting to represent N way inputs. + val inputTypeNodes = + inputAttributes.map( + attributes => + attributes.map(attr => ConverterUtils.getTypeNode(attr.dataType, attr.nullable)).asJava) + val extensionNode = ExtensionBuilder.makeAdvancedExtension( + BackendsApiManager.getTransformerApiInstance.packPBMessage( + TypeBuilder + .makeStruct( + false, + inputTypeNodes.map(nodes => TypeBuilder.makeStruct(false, nodes)).asJava) + .toProtobuf)) + return RelBuilder.makeSetRel( + inputs.asJava, + SetOp.SET_OP_UNION_ALL, + extensionNode, + context, + operatorId) + } + RelBuilder.makeSetRel(inputs.asJava, SetOp.SET_OP_UNION_ALL, context, operatorId) + } +} diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/UnionTransformerRule.scala b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/UnionTransformerRule.scala new file mode 100644 index 000000000000..f0eea08018dd --- /dev/null +++ b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/UnionTransformerRule.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.gluten.extension.columnar + +import org.apache.gluten.GlutenConfig +import org.apache.gluten.execution.{ColumnarUnionExec, UnionExecTransformer} + +import org.apache.spark.sql.catalyst.plans.physical.UnknownPartitioning +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.SparkPlan + +/** + * Replace ColumnarUnionExec with UnionExecTransformer if possible. + * + * The rule is not included in [[org.apache.gluten.extension.columnar.heuristic.HeuristicTransform]] + * or [[org.apache.gluten.extension.columnar.enumerated.EnumeratedTransform]] because it relies on + * children's output partitioning to be fully provided. + */ +case class UnionTransformerRule() extends Rule[SparkPlan] { + override def apply(plan: SparkPlan): SparkPlan = { + if (!GlutenConfig.getConf.enableNativeUnion) { + return plan + } + plan.transformUp { + case plan: ColumnarUnionExec => + val transformer = UnionExecTransformer(plan.children) + if (sameNumPartitions(plan.children) && validate(transformer)) { + transformer + } else { + plan + } + } + } + + private def sameNumPartitions(plans: Seq[SparkPlan]): Boolean = { + val partitioning = plans.map(_.outputPartitioning) + if (partitioning.exists(p => p.isInstanceOf[UnknownPartitioning])) { + return false + } + val numPartitions = plans.map(_.outputPartitioning.numPartitions) + numPartitions.forall(_ == numPartitions.head) + } + + private def validate(union: UnionExecTransformer): Boolean = { + union.doValidate().ok() + } +} diff --git a/gluten-substrait/src/test/scala/org/apache/gluten/execution/WholeStageTransformerSuite.scala b/gluten-substrait/src/test/scala/org/apache/gluten/execution/WholeStageTransformerSuite.scala index fd250834d078..08081fadb5f9 100644 --- a/gluten-substrait/src/test/scala/org/apache/gluten/execution/WholeStageTransformerSuite.scala +++ b/gluten-substrait/src/test/scala/org/apache/gluten/execution/WholeStageTransformerSuite.scala @@ -176,25 +176,39 @@ abstract class WholeStageTransformerSuite result } + protected def compareResultsAgainstVanillaSpark( + sql: String, + compareResult: Boolean = true, + customCheck: DataFrame => Unit, + noFallBack: Boolean = true, + cache: Boolean = false): DataFrame = { + compareDfResultsAgainstVanillaSpark( + () => spark.sql(sql), + compareResult, + customCheck, + noFallBack, + cache) + } + /** * run a query with native engine as well as vanilla spark then compare the result set for * correctness check */ - protected def compareResultsAgainstVanillaSpark( - sqlStr: String, + protected def compareDfResultsAgainstVanillaSpark( + dataframe: () => DataFrame, compareResult: Boolean = true, customCheck: DataFrame => Unit, noFallBack: Boolean = true, cache: Boolean = false): DataFrame = { var expected: Seq[Row] = null withSQLConf(vanillaSparkConfs(): _*) { - val df = spark.sql(sqlStr) + val df = dataframe() expected = df.collect() } - // By default we will fallabck complex type scan but here we should allow + // By default, we will fallback complex type scan but here we should allow // to test support of complex type spark.conf.set("spark.gluten.sql.complexType.scan.fallback.enabled", "false"); - val df = spark.sql(sqlStr) + val df = dataframe() if (cache) { df.cache() } @@ -239,7 +253,12 @@ abstract class WholeStageTransformerSuite noFallBack: Boolean = true, cache: Boolean = false)(customCheck: DataFrame => Unit): DataFrame = { - compareResultsAgainstVanillaSpark(sqlStr, compareResult, customCheck, noFallBack, cache) + compareDfResultsAgainstVanillaSpark( + () => spark.sql(sqlStr), + compareResult, + customCheck, + noFallBack, + cache) } /** @@ -256,8 +275,8 @@ abstract class WholeStageTransformerSuite customCheck: DataFrame => Unit, noFallBack: Boolean = true, compareResult: Boolean = true): Unit = - compareResultsAgainstVanillaSpark( - tpchSQL(queryNum, tpchQueries), + compareDfResultsAgainstVanillaSpark( + () => spark.sql(tpchSQL(queryNum, tpchQueries)), compareResult, customCheck, noFallBack) diff --git a/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala b/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala index c4c67f49a59b..969fc072c21b 100644 --- a/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala +++ b/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala @@ -86,6 +86,8 @@ class GlutenConfig(conf: SQLConf) extends Logging { def enableColumnarUnion: Boolean = conf.getConf(COLUMNAR_UNION_ENABLED) + def enableNativeUnion: Boolean = conf.getConf(NATIVE_UNION_ENABLED) + def enableColumnarExpand: Boolean = conf.getConf(COLUMNAR_EXPAND_ENABLED) def enableColumnarBroadcastExchange: Boolean = conf.getConf(COLUMNAR_BROADCAST_EXCHANGE_ENABLED) @@ -1012,6 +1014,13 @@ object GlutenConfig { .booleanConf .createWithDefault(true) + val NATIVE_UNION_ENABLED = + buildConf("spark.gluten.sql.native.union") + .internal() + .doc("Enable or disable native union where computation is completely offloaded to backend.") + .booleanConf + .createWithDefault(false) + val COLUMNAR_EXPAND_ENABLED = buildConf("spark.gluten.sql.columnar.expand") .internal()