diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHMetricsApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHMetricsApi.scala
index 73b2d0f21101..a0576a807b98 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHMetricsApi.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHMetricsApi.scala
@@ -450,6 +450,14 @@ class CHMetricsApi extends MetricsApi with Logging with LogLevelUtil {
s"SampleTransformer metrics update is not supported in CH backend")
}
+ override def genUnionTransformerMetrics(sparkContext: SparkContext): Map[String, SQLMetric] =
+ throw new UnsupportedOperationException(
+ "UnionExecTransformer metrics update is not supported in CH backend")
+
+ override def genUnionTransformerMetricsUpdater(metrics: Map[String, SQLMetric]): MetricsUpdater =
+ throw new UnsupportedOperationException(
+ "UnionExecTransformer metrics update is not supported in CH backend")
+
def genWriteFilesTransformerMetrics(sparkContext: SparkContext): Map[String, SQLMetric] =
Map(
"physicalWrittenBytes" -> SQLMetrics.createMetric(sparkContext, "number of written bytes"),
diff --git a/backends-velox/src/main/java/org/apache/gluten/utils/GlutenURLDecoder.java b/backends-velox/src/main/java/org/apache/gluten/utils/GlutenURLDecoder.java
index 9228a2f860ae..856ddf159730 100644
--- a/backends-velox/src/main/java/org/apache/gluten/utils/GlutenURLDecoder.java
+++ b/backends-velox/src/main/java/org/apache/gluten/utils/GlutenURLDecoder.java
@@ -31,7 +31,7 @@ public class GlutenURLDecoder {
*
Note: The World Wide Web Consortium
* Recommendation states that UTF-8 should be used. Not doing so may introduce
- * incompatibilites.
+ * incompatibilities.
*
* @param s the String
to decode
* @param enc The name of a supported character
diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxListenerApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxListenerApi.scala
index d29d3029709e..3a82abe61833 100644
--- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxListenerApi.scala
+++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxListenerApi.scala
@@ -31,6 +31,7 @@ import org.apache.spark.{HdfsConfGenerator, SparkConf, SparkContext}
import org.apache.spark.api.plugin.PluginContext
import org.apache.spark.internal.Logging
import org.apache.spark.network.util.ByteUnit
+import org.apache.spark.sql.execution.ColumnarCachedBatchSerializer
import org.apache.spark.sql.execution.datasources.GlutenWriterColumnarRules
import org.apache.spark.sql.execution.datasources.velox.{VeloxParquetWriterInjects, VeloxRowSplitter}
import org.apache.spark.sql.expression.UDFResolver
@@ -75,7 +76,7 @@ class VeloxListenerApi extends ListenerApi with Logging {
if (conf.getBoolean(GlutenConfig.COLUMNAR_TABLE_CACHE_ENABLED.key, defaultValue = false)) {
conf.set(
StaticSQLConf.SPARK_CACHE_SERIALIZER.key,
- "org.apache.spark.sql.execution.ColumnarCachedBatchSerializer")
+ classOf[ColumnarCachedBatchSerializer].getName)
}
// Static initializers for driver.
diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxMetricsApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxMetricsApi.scala
index e70e1d13bdfe..934b680382ea 100644
--- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxMetricsApi.scala
+++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxMetricsApi.scala
@@ -582,4 +582,15 @@ class VeloxMetricsApi extends MetricsApi with Logging {
override def genSampleTransformerMetricsUpdater(metrics: Map[String, SQLMetric]): MetricsUpdater =
new SampleMetricsUpdater(metrics)
+
+ override def genUnionTransformerMetrics(sparkContext: SparkContext): Map[String, SQLMetric] = Map(
+ "numInputRows" -> SQLMetrics.createMetric(sparkContext, "number of input rows"),
+ "inputVectors" -> SQLMetrics.createMetric(sparkContext, "number of input vectors"),
+ "inputBytes" -> SQLMetrics.createSizeMetric(sparkContext, "number of input bytes"),
+ "wallNanos" -> SQLMetrics.createNanoTimingMetric(sparkContext, "time of union"),
+ "cpuCount" -> SQLMetrics.createMetric(sparkContext, "cpu wall time count")
+ )
+
+ override def genUnionTransformerMetricsUpdater(metrics: Map[String, SQLMetric]): MetricsUpdater =
+ new UnionMetricsUpdater(metrics)
}
diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala
index 7841e6cd94b1..7337be573710 100644
--- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala
+++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala
@@ -92,6 +92,7 @@ object VeloxRuleApi {
c => HeuristicTransform.Single(validatorBuilder(c.glutenConf), rewrites, offloads))
// Legacy: Post-transform rules.
+ injector.injectPostTransform(_ => UnionTransformerRule())
injector.injectPostTransform(c => PartialProjectRule.apply(c.session))
injector.injectPostTransform(_ => RemoveNativeWriteFilesSortAndProject())
injector.injectPostTransform(c => RewriteTransformer.apply(c.session))
@@ -178,6 +179,7 @@ object VeloxRuleApi {
// Gluten RAS: Post rules.
injector.injectPostTransform(_ => RemoveTransitions)
+ injector.injectPostTransform(_ => UnionTransformerRule())
injector.injectPostTransform(c => PartialProjectRule.apply(c.session))
injector.injectPostTransform(_ => RemoveNativeWriteFilesSortAndProject())
injector.injectPostTransform(c => RewriteTransformer.apply(c.session))
diff --git a/backends-velox/src/main/scala/org/apache/gluten/metrics/MetricsUtil.scala b/backends-velox/src/main/scala/org/apache/gluten/metrics/MetricsUtil.scala
index cd50d0b8e20c..b8ef1620f905 100644
--- a/backends-velox/src/main/scala/org/apache/gluten/metrics/MetricsUtil.scala
+++ b/backends-velox/src/main/scala/org/apache/gluten/metrics/MetricsUtil.scala
@@ -58,7 +58,8 @@ object MetricsUtil extends Logging {
assert(t.children.size == 1, "MetricsUpdater.None can only be used on unary operator")
treeifyMetricsUpdaters(t.children.head)
case t: TransformSupport =>
- MetricsUpdaterTree(t.metricsUpdater(), t.children.map(treeifyMetricsUpdaters))
+ // Reversed children order to match the traversal code.
+ MetricsUpdaterTree(t.metricsUpdater(), t.children.reverse.map(treeifyMetricsUpdaters))
case _ =>
MetricsUpdaterTree(MetricsUpdater.Terminate, Seq())
}
@@ -233,6 +234,12 @@ object MetricsUtil extends Logging {
operatorMetrics,
metrics.getSingleMetrics,
joinParamsMap.get(operatorIdx))
+ case u: UnionMetricsUpdater =>
+ // JoinRel outputs two suites of metrics respectively for hash build and hash probe.
+ // Therefore, fetch one more suite of metrics here.
+ operatorMetrics.add(metrics.getOperatorMetrics(curMetricsIdx))
+ curMetricsIdx -= 1
+ u.updateUnionMetrics(operatorMetrics)
case hau: HashAggregateMetricsUpdater =>
hau.updateAggregationMetrics(operatorMetrics, aggParamsMap.get(operatorIdx))
case lu: LimitMetricsUpdater =>
diff --git a/backends-velox/src/main/scala/org/apache/gluten/metrics/UnionMetricsUpdater.scala b/backends-velox/src/main/scala/org/apache/gluten/metrics/UnionMetricsUpdater.scala
new file mode 100644
index 000000000000..9e91cf368c0a
--- /dev/null
+++ b/backends-velox/src/main/scala/org/apache/gluten/metrics/UnionMetricsUpdater.scala
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.metrics
+
+import org.apache.spark.sql.execution.metric.SQLMetric
+
+class UnionMetricsUpdater(val metrics: Map[String, SQLMetric]) extends MetricsUpdater {
+ override def updateNativeMetrics(opMetrics: IOperatorMetrics): Unit = {
+ throw new UnsupportedOperationException()
+ }
+
+ def updateUnionMetrics(unionMetrics: java.util.ArrayList[OperatorMetrics]): Unit = {
+ // Union was interpreted to LocalExchange + LocalPartition. Use metrics from LocalExchange.
+ val localExchangeMetrics = unionMetrics.get(0)
+ metrics("numInputRows") += localExchangeMetrics.inputRows
+ metrics("inputVectors") += localExchangeMetrics.inputVectors
+ metrics("inputBytes") += localExchangeMetrics.inputBytes
+ metrics("cpuCount") += localExchangeMetrics.cpuCount
+ metrics("wallNanos") += localExchangeMetrics.wallNanos
+ }
+}
diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/MiscOperatorSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/MiscOperatorSuite.scala
index 5cb2b652604d..8063a5d12207 100644
--- a/backends-velox/src/test/scala/org/apache/gluten/execution/MiscOperatorSuite.scala
+++ b/backends-velox/src/test/scala/org/apache/gluten/execution/MiscOperatorSuite.scala
@@ -537,11 +537,37 @@ class MiscOperatorSuite extends VeloxWholeStageTransformerSuite with AdaptiveSpa
|""".stripMargin) {
df =>
{
- getExecutedPlan(df).exists(plan => plan.find(_.isInstanceOf[ColumnarUnionExec]).isDefined)
+ assert(
+ getExecutedPlan(df).exists(
+ plan => plan.find(_.isInstanceOf[ColumnarUnionExec]).isDefined))
}
}
}
+ test("union_all two tables with known partitioning") {
+ withSQLConf(GlutenConfig.NATIVE_UNION_ENABLED.key -> "true") {
+ compareDfResultsAgainstVanillaSpark(
+ () => {
+ val df1 = spark.sql("select l_orderkey as orderkey from lineitem")
+ val df2 = spark.sql("select o_orderkey as orderkey from orders")
+ df1.repartition(5).union(df2.repartition(5))
+ },
+ compareResult = true,
+ checkGlutenOperatorMatch[UnionExecTransformer]
+ )
+
+ compareDfResultsAgainstVanillaSpark(
+ () => {
+ val df1 = spark.sql("select l_orderkey as orderkey from lineitem")
+ val df2 = spark.sql("select o_orderkey as orderkey from orders")
+ df1.repartition(5).union(df2.repartition(6))
+ },
+ compareResult = true,
+ checkGlutenOperatorMatch[ColumnarUnionExec]
+ )
+ }
+ }
+
test("union_all three tables") {
runQueryAndCompare("""
|select count(orderkey) from (
diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxOrcDataTypeValidationSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxOrcDataTypeValidationSuite.scala
index 24e04f2dfce3..6ac59ba4fa6b 100644
--- a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxOrcDataTypeValidationSuite.scala
+++ b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxOrcDataTypeValidationSuite.scala
@@ -255,7 +255,10 @@ class VeloxOrcDataTypeValidationSuite extends VeloxWholeStageTransformerSuite {
|""".stripMargin) {
df =>
{
- assert(getExecutedPlan(df).exists(plan => plan.isInstanceOf[ColumnarUnionExec]))
+ assert(
+ getExecutedPlan(df).exists(
+ plan =>
+ plan.isInstanceOf[ColumnarUnionExec] || plan.isInstanceOf[UnionExecTransformer]))
}
}
diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxParquetDataTypeValidationSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxParquetDataTypeValidationSuite.scala
index 57ca448fec79..cb5614f39669 100644
--- a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxParquetDataTypeValidationSuite.scala
+++ b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxParquetDataTypeValidationSuite.scala
@@ -254,7 +254,10 @@ class VeloxParquetDataTypeValidationSuite extends VeloxWholeStageTransformerSuit
|""".stripMargin) {
df =>
{
- assert(getExecutedPlan(df).exists(plan => plan.isInstanceOf[ColumnarUnionExec]))
+ assert(
+ getExecutedPlan(df).exists(
+ plan =>
+ plan.isInstanceOf[ColumnarUnionExec] || plan.isInstanceOf[UnionExecTransformer]))
}
}
diff --git a/cpp/velox/compute/WholeStageResultIterator.cc b/cpp/velox/compute/WholeStageResultIterator.cc
index b6ecbd959f09..411c6c563646 100644
--- a/cpp/velox/compute/WholeStageResultIterator.cc
+++ b/cpp/velox/compute/WholeStageResultIterator.cc
@@ -91,7 +91,7 @@ WholeStageResultIterator::WholeStageResultIterator(
std::move(queryCtx),
velox::exec::Task::ExecutionMode::kSerial);
if (!task_->supportSerialExecutionMode()) {
- throw std::runtime_error("Task doesn't support single thread execution: " + planNode->toString());
+ throw std::runtime_error("Task doesn't support single threaded execution: " + planNode->toString());
}
auto fileSystem = velox::filesystems::getFileSystem(spillDir, nullptr);
GLUTEN_CHECK(fileSystem != nullptr, "File System for spilling is null!");
@@ -248,15 +248,47 @@ void WholeStageResultIterator::getOrderedNodeIds(
const std::shared_ptr& planNode,
std::vector& nodeIds) {
bool isProjectNode = (std::dynamic_pointer_cast(planNode) != nullptr);
+ bool isLocalExchangeNode = (std::dynamic_pointer_cast(planNode) != nullptr);
+ bool isUnionNode = isLocalExchangeNode &&
+ std::dynamic_pointer_cast(planNode)->type() ==
+ velox::core::LocalPartitionNode::Type::kGather;
const auto& sourceNodes = planNode->sources();
- for (const auto& sourceNode : sourceNodes) {
+ if (isProjectNode) {
+ GLUTEN_CHECK(sourceNodes.size() == 1, "Illegal state");
+ const auto sourceNode = sourceNodes.at(0);
// Filter over Project are mapped into FilterProject operator in Velox.
// Metrics are all applied on Project node, and the metrics for Filter node
// do not exist.
- if (isProjectNode && std::dynamic_pointer_cast(sourceNode)) {
+ if (std::dynamic_pointer_cast(sourceNode)) {
omittedNodeIds_.insert(sourceNode->id());
}
getOrderedNodeIds(sourceNode, nodeIds);
+ nodeIds.emplace_back(planNode->id());
+ return;
+ }
+
+ if (isUnionNode) {
+ // FIXME: The whole metrics system in gluten-substrait is magic. Passing metrics trees through JNI with a trivial
+ // array is possible but requires for a solid design. Apparently we haven't had it. All the code requires complete
+ // rework.
+ // Union was interpreted as LocalPartition + LocalExchange + 2 fake projects as children in Velox. So we only fetch
+ // metrics from the root node.
+ std::vector> unionChildren{};
+ for (const auto& source : planNode->sources()) {
+ const auto projectedChild = std::dynamic_pointer_cast(source);
+ GLUTEN_CHECK(projectedChild != nullptr, "Illegal state");
+ const auto projectSources = projectedChild->sources();
+ GLUTEN_CHECK(projectSources.size() == 1, "Illegal state");
+ const auto projectSource = projectSources.at(0);
+ getOrderedNodeIds(projectSource, nodeIds);
+ }
+ nodeIds.emplace_back(planNode->id());
+ return;
+ }
+
+ for (const auto& sourceNode : sourceNodes) {
+ // Post-order traversal.
+ getOrderedNodeIds(sourceNode, nodeIds);
}
nodeIds.emplace_back(planNode->id());
}
@@ -350,9 +382,9 @@ void WholeStageResultIterator::collectMetrics() {
continue;
}
- const auto& status = planStats.at(nodeId);
- // Add each operator status into metrics.
- for (const auto& entry : status.operatorStats) {
+ const auto& stats = planStats.at(nodeId);
+ // Add each operator stats into metrics.
+ for (const auto& entry : stats.operatorStats) {
const auto& second = entry.second;
metrics_->get(Metrics::kInputRows)[metricIndex] = second->inputRows;
metrics_->get(Metrics::kInputVectors)[metricIndex] = second->inputVectors;
diff --git a/cpp/velox/substrait/SubstraitToVeloxPlan.cc b/cpp/velox/substrait/SubstraitToVeloxPlan.cc
index 1efa7338796d..3ceccca4a3de 100644
--- a/cpp/velox/substrait/SubstraitToVeloxPlan.cc
+++ b/cpp/velox/substrait/SubstraitToVeloxPlan.cc
@@ -1043,6 +1043,50 @@ core::PlanNodePtr SubstraitToVeloxPlanConverter::toVeloxPlan(
childNode);
}
+core::PlanNodePtr SubstraitToVeloxPlanConverter::toVeloxPlan(const ::substrait::SetRel& setRel) {
+ switch (setRel.op()) {
+ case ::substrait::SetRel_SetOp::SetRel_SetOp_SET_OP_UNION_ALL: {
+ std::vector children;
+ for (int32_t i = 0; i < setRel.inputs_size(); ++i) {
+ const auto& input = setRel.inputs(i);
+ children.push_back(toVeloxPlan(input));
+ }
+ GLUTEN_CHECK(!children.empty(), "At least one source is required for Velox LocalPartition");
+
+ // Velox doesn't allow different field names in schemas of LocalPartitionNode's children.
+ // Add project nodes to unify the schemas.
+ const RowTypePtr outRowType = asRowType(children[0]->outputType());
+ std::vector outNames;
+ for (int32_t colIdx = 0; colIdx < outRowType->size(); ++colIdx) {
+ const auto name = outRowType->childAt(colIdx)->name();
+ outNames.push_back(name);
+ }
+
+ std::vector projectedChildren;
+ for (int32_t i = 0; i < children.size(); ++i) {
+ const auto& child = children[i];
+ const RowTypePtr& childRowType = child->outputType();
+ std::vector expressions;
+ for (int32_t colIdx = 0; colIdx < outNames.size(); ++colIdx) {
+ const auto fa =
+ std::make_shared(childRowType->childAt(colIdx), childRowType->nameOf(colIdx));
+ const auto cast = std::make_shared(outRowType->childAt(colIdx), fa, false);
+ expressions.push_back(cast);
+ }
+ auto project = std::make_shared(nextPlanNodeId(), outNames, expressions, child);
+ projectedChildren.push_back(project);
+ }
+ return std::make_shared(
+ nextPlanNodeId(),
+ core::LocalPartitionNode::Type::kGather,
+ std::make_shared(),
+ projectedChildren);
+ }
+ default:
+ throw GlutenException("Unsupported SetRel op: " + std::to_string(setRel.op()));
+ }
+}
+
core::PlanNodePtr SubstraitToVeloxPlanConverter::toVeloxPlan(const ::substrait::SortRel& sortRel) {
auto childNode = convertSingleInput<::substrait::SortRel>(sortRel);
auto [sortingKeys, sortingOrders] = processSortField(sortRel.sorts(), childNode->outputType());
@@ -1298,6 +1342,8 @@ core::PlanNodePtr SubstraitToVeloxPlanConverter::toVeloxPlan(const ::substrait::
return toVeloxPlan(rel.write());
} else if (rel.has_windowgrouplimit()) {
return toVeloxPlan(rel.windowgrouplimit());
+ } else if (rel.has_set()) {
+ return toVeloxPlan(rel.set());
} else {
VELOX_NYI("Substrait conversion not supported for Rel.");
}
diff --git a/cpp/velox/substrait/SubstraitToVeloxPlan.h b/cpp/velox/substrait/SubstraitToVeloxPlan.h
index 51e50ce34767..6121923df787 100644
--- a/cpp/velox/substrait/SubstraitToVeloxPlan.h
+++ b/cpp/velox/substrait/SubstraitToVeloxPlan.h
@@ -84,6 +84,9 @@ class SubstraitToVeloxPlanConverter {
/// Used to convert Substrait WindowGroupLimitRel into Velox PlanNode.
core::PlanNodePtr toVeloxPlan(const ::substrait::WindowGroupLimitRel& windowGroupLimitRel);
+ /// Used to convert Substrait SetRel into Velox PlanNode.
+ core::PlanNodePtr toVeloxPlan(const ::substrait::SetRel& setRel);
+
/// Used to convert Substrait JoinRel into Velox PlanNode.
core::PlanNodePtr toVeloxPlan(const ::substrait::JoinRel& joinRel);
diff --git a/cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc b/cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc
index 3b74caf8ba5a..9325fed3217c 100644
--- a/cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc
+++ b/cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc
@@ -22,7 +22,6 @@
#include "TypeUtils.h"
#include "udf/UdfLoader.h"
#include "utils/Common.h"
-#include "velox/core/ExpressionEvaluator.h"
#include "velox/exec/Aggregate.h"
#include "velox/expression/Expr.h"
#include "velox/expression/SignatureBinder.h"
@@ -30,7 +29,7 @@
namespace gluten {
namespace {
-static const char* extractFileName(const char* file) {
+const char* extractFileName(const char* file) {
return strrchr(file, '/') ? strrchr(file, '/') + 1 : file;
}
@@ -53,13 +52,13 @@ static const char* extractFileName(const char* file) {
__FUNCTION__, \
reason))
-static const std::unordered_set kRegexFunctions = {
+const std::unordered_set kRegexFunctions = {
"regexp_extract",
"regexp_extract_all",
"regexp_replace",
"rlike"};
-static const std::unordered_set kBlackList = {
+const std::unordered_set kBlackList = {
"split_part",
"factorial",
"concat_ws",
@@ -70,32 +69,59 @@ static const std::unordered_set kBlackList = {
"approx_percentile",
"get_array_struct_fields",
"map_from_arrays"};
-
} // namespace
-bool SubstraitToVeloxPlanValidator::validateInputTypes(
+bool SubstraitToVeloxPlanValidator::parseVeloxType(
const ::substrait::extensions::AdvancedExtension& extension,
- std::vector& types) {
+ TypePtr& out) {
+ ::substrait::Type substraitType;
// The input type is wrapped in enhancement.
if (!extension.has_enhancement()) {
LOG_VALIDATION_MSG("Input type is not wrapped in enhancement.");
return false;
}
const auto& enhancement = extension.enhancement();
- ::substrait::Type inputType;
- if (!enhancement.UnpackTo(&inputType)) {
+ if (!enhancement.UnpackTo(&substraitType)) {
LOG_VALIDATION_MSG("Enhancement can't be unpacked to inputType.");
return false;
}
- if (!inputType.has_struct_()) {
- LOG_VALIDATION_MSG("Input type has no struct.");
+
+ out = SubstraitParser::parseType(substraitType);
+ return true;
+}
+
+bool SubstraitToVeloxPlanValidator::flattenVeloxType1(const TypePtr& type, std::vector& out) {
+ if (type->kind() != TypeKind::ROW) {
+ LOG_VALIDATION_MSG("Type is not a RowType.");
+ return false;
+ }
+ auto rowType = std::dynamic_pointer_cast(type);
+ if (!rowType) {
+ LOG_VALIDATION_MSG("Failed to cast to RowType.");
return false;
}
+ for (const auto& field : rowType->children()) {
+ out.emplace_back(field);
+ }
+ return true;
+}
- // Get the input types.
- const auto& sTypes = inputType.struct_().types();
- for (const auto& sType : sTypes) {
- types.emplace_back(SubstraitParser::parseType(sType));
+bool SubstraitToVeloxPlanValidator::flattenVeloxType2(const TypePtr& type, std::vector>& out) {
+ if (type->kind() != TypeKind::ROW) {
+ LOG_VALIDATION_MSG("Type is not a RowType.");
+ return false;
+ }
+ auto rowType = std::dynamic_pointer_cast(type);
+ if (!rowType) {
+ LOG_VALIDATION_MSG("Failed to cast to RowType.");
+ return false;
+ }
+ for (const auto& field : rowType->children()) {
+ std::vector inner;
+ if (!flattenVeloxType1(field, inner)) {
+ return false;
+ }
+ out.emplace_back(inner);
}
return true;
}
@@ -341,10 +367,11 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::WriteRel& writeR
}
// Validate input data type.
+ TypePtr inputRowType;
std::vector types;
if (writeRel.has_named_table()) {
const auto& extension = writeRel.named_table().advanced_extension();
- if (!validateInputTypes(extension, types)) {
+ if (!parseVeloxType(extension, inputRowType) || !flattenVeloxType1(inputRowType, types)) {
LOG_VALIDATION_MSG("Validation failed for input type validation in WriteRel.");
return false;
}
@@ -380,12 +407,12 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::WriteRel& writeR
}
bool SubstraitToVeloxPlanValidator::validate(const ::substrait::FetchRel& fetchRel) {
- RowTypePtr rowType = nullptr;
// Get and validate the input types from extension.
if (fetchRel.has_advanced_extension()) {
const auto& extension = fetchRel.advanced_extension();
+ TypePtr inputRowType;
std::vector types;
- if (!validateInputTypes(extension, types)) {
+ if (!parseVeloxType(extension, inputRowType) || !flattenVeloxType1(inputRowType, types)) {
LOG_VALIDATION_MSG("Unsupported input types in FetchRel.");
return false;
}
@@ -396,7 +423,6 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::FetchRel& fetchR
for (auto colIdx = 0; colIdx < types.size(); colIdx++) {
names.emplace_back(SubstraitParser::makeNodeName(inputPlanNodeId, colIdx));
}
- rowType = std::make_shared(std::move(names), std::move(types));
}
if (fetchRel.offset() < 0 || fetchRel.count() < 0) {
@@ -412,8 +438,9 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::TopNRel& topNRel
// Get and validate the input types from extension.
if (topNRel.has_advanced_extension()) {
const auto& extension = topNRel.advanced_extension();
+ TypePtr inputRowType;
std::vector types;
- if (!validateInputTypes(extension, types)) {
+ if (!parseVeloxType(extension, inputRowType) || !flattenVeloxType1(inputRowType, types)) {
LOG_VALIDATION_MSG("Unsupported input types in TopNRel.");
return false;
}
@@ -457,8 +484,9 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::GenerateRel& gen
return false;
}
const auto& extension = generateRel.advanced_extension();
+ TypePtr inputRowType;
std::vector types;
- if (!validateInputTypes(extension, types)) {
+ if (!parseVeloxType(extension, inputRowType) || !flattenVeloxType1(inputRowType, types)) {
LOG_VALIDATION_MSG("Validation failed for input types in GenerateRel.");
return false;
}
@@ -487,8 +515,9 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::ExpandRel& expan
// Get and validate the input types from extension.
if (expandRel.has_advanced_extension()) {
const auto& extension = expandRel.advanced_extension();
+ TypePtr inputRowType;
std::vector types;
- if (!validateInputTypes(extension, types)) {
+ if (!parseVeloxType(extension, inputRowType) || !flattenVeloxType1(inputRowType, types)) {
LOG_VALIDATION_MSG("Unsupported input types in ExpandRel.");
return false;
}
@@ -571,8 +600,9 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::WindowRel& windo
return false;
}
const auto& extension = windowRel.advanced_extension();
+ TypePtr inputRowType;
std::vector types;
- if (!validateInputTypes(extension, types)) {
+ if (!parseVeloxType(extension, inputRowType) || !flattenVeloxType1(inputRowType, types)) {
LOG_VALIDATION_MSG("Validation failed for input types in WindowRel.");
return false;
}
@@ -680,7 +710,7 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::WindowRel& windo
LOG_VALIDATION_MSG("in windowRel, the sorting key in Sort Operator only support field.");
return false;
}
- exec::ExprSet exprSet({std::move(expression)}, execCtx_);
+ exec::ExprSet exprSet1({std::move(expression)}, execCtx_);
}
}
@@ -699,8 +729,9 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::WindowGroupLimit
return false;
}
const auto& extension = windowGroupLimitRel.advanced_extension();
+ TypePtr inputRowType;
std::vector types;
- if (!validateInputTypes(extension, types)) {
+ if (!parseVeloxType(extension, inputRowType) || !flattenVeloxType1(inputRowType, types)) {
LOG_VALIDATION_MSG("Validation failed for input types in WindowGroupLimitRel.");
return false;
}
@@ -750,13 +781,61 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::WindowGroupLimit
LOG_VALIDATION_MSG("in windowGroupLimitRel, the sorting key in Sort Operator only support field.");
return false;
}
- exec::ExprSet exprSet({std::move(expression)}, execCtx_);
+ exec::ExprSet exprSet1({std::move(expression)}, execCtx_);
}
}
return true;
}
+bool SubstraitToVeloxPlanValidator::validate(const ::substrait::SetRel& setRel) {
+ switch (setRel.op()) {
+ case ::substrait::SetRel_SetOp::SetRel_SetOp_SET_OP_UNION_ALL: {
+ for (int32_t i = 0; i < setRel.inputs_size(); ++i) {
+ const auto& input = setRel.inputs(i);
+ if (!validate(input)) {
+ LOG_VALIDATION_MSG("ProjectRel input");
+ return false;
+ }
+ }
+ if (!setRel.has_advanced_extension()) {
+ LOG_VALIDATION_MSG("Input types are expected in SetRel.");
+ return false;
+ }
+ const auto& extension = setRel.advanced_extension();
+ TypePtr inputRowType;
+ std::vector> childrenTypes;
+ if (!parseVeloxType(extension, inputRowType) || !flattenVeloxType2(inputRowType, childrenTypes)) {
+ LOG_VALIDATION_MSG("Validation failed for input types in SetRel.");
+ return false;
+ }
+ std::vector childrenRowTypes;
+ for (auto i = 0; i < childrenTypes.size(); ++i) {
+ auto& types = childrenTypes.at(i);
+ std::vector names;
+ names.reserve(types.size());
+ for (auto colIdx = 0; colIdx < types.size(); colIdx++) {
+ names.emplace_back(SubstraitParser::makeNodeName(i, colIdx));
+ }
+ childrenRowTypes.push_back(std::make_shared(std::move(names), std::move(types)));
+ }
+
+ for (auto i = 1; i < childrenRowTypes.size(); ++i) {
+ if (!(childrenRowTypes[i]->equivalent(*childrenRowTypes[0]))) {
+ LOG_VALIDATION_MSG(
+ "All sources of the Set operation must have the same output type: " + childrenRowTypes[i]->toString() +
+ " vs. " + childrenRowTypes[0]->toString());
+ return false;
+ }
+ }
+ return true;
+ }
+ default:
+ LOG_VALIDATION_MSG("Unsupported SetRel op: " + std::to_string(setRel.op()));
+ return false;
+ }
+}
+
bool SubstraitToVeloxPlanValidator::validate(const ::substrait::SortRel& sortRel) {
if (sortRel.has_input() && !validate(sortRel.input())) {
return false;
@@ -769,8 +848,9 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::SortRel& sortRel
}
const auto& extension = sortRel.advanced_extension();
+ TypePtr inputRowType;
std::vector types;
- if (!validateInputTypes(extension, types)) {
+ if (!parseVeloxType(extension, inputRowType) || !flattenVeloxType1(inputRowType, types)) {
LOG_VALIDATION_MSG("Validation failed for input types in SortRel.");
return false;
}
@@ -822,8 +902,9 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::ProjectRel& proj
return false;
}
const auto& extension = projectRel.advanced_extension();
+ TypePtr inputRowType;
std::vector types;
- if (!validateInputTypes(extension, types)) {
+ if (!parseVeloxType(extension, inputRowType) || !flattenVeloxType1(inputRowType, types)) {
LOG_VALIDATION_MSG("Validation failed for input types in ProjectRel.");
return false;
}
@@ -865,8 +946,9 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::FilterRel& filte
return false;
}
const auto& extension = filterRel.advanced_extension();
+ TypePtr inputRowType;
std::vector types;
- if (!validateInputTypes(extension, types)) {
+ if (!parseVeloxType(extension, inputRowType) || !flattenVeloxType1(inputRowType, types)) {
LOG_VALIDATION_MSG("Validation failed for input types in FilterRel.");
return false;
}
@@ -938,8 +1020,9 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::JoinRel& joinRel
}
const auto& extension = joinRel.advanced_extension();
+ TypePtr inputRowType;
std::vector types;
- if (!validateInputTypes(extension, types)) {
+ if (!parseVeloxType(extension, inputRowType) || !flattenVeloxType1(inputRowType, types)) {
LOG_VALIDATION_MSG("Validation failed for input types in JoinRel.");
return false;
}
@@ -991,8 +1074,9 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::CrossRel& crossR
}
const auto& extension = crossRel.advanced_extension();
+ TypePtr inputRowType;
std::vector types;
- if (!validateInputTypes(extension, types)) {
+ if (!parseVeloxType(extension, inputRowType) || !flattenVeloxType1(inputRowType, types)) {
logValidateMsg("Native validation failed due to: Validation failed for input types in CrossRel");
return false;
}
@@ -1070,11 +1154,13 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::AggregateRel& ag
// Validate input types.
if (aggRel.has_advanced_extension()) {
+ TypePtr inputRowType;
std::vector types;
const auto& extension = aggRel.advanced_extension();
// Aggregate always has advanced extension for streaming aggregate optimization,
// but only some of them have enhancement for validation.
- if (extension.has_enhancement() && !validateInputTypes(extension, types)) {
+ if (extension.has_enhancement() &&
+ (!parseVeloxType(extension, inputRowType) || !flattenVeloxType1(inputRowType, types))) {
LOG_VALIDATION_MSG("Validation failed for input types in AggregateRel.");
return false;
}
@@ -1266,7 +1352,10 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::Rel& rel) {
return validate(rel.write());
} else if (rel.has_windowgrouplimit()) {
return validate(rel.windowgrouplimit());
+ } else if (rel.has_set()) {
+ return validate(rel.set());
} else {
+ LOG_VALIDATION_MSG("Unsupported relation type: " + rel.GetTypeName());
return false;
}
}
diff --git a/cpp/velox/substrait/SubstraitToVeloxPlanValidator.h b/cpp/velox/substrait/SubstraitToVeloxPlanValidator.h
index 1fe174928fd9..0c8d882ca031 100644
--- a/cpp/velox/substrait/SubstraitToVeloxPlanValidator.h
+++ b/cpp/velox/substrait/SubstraitToVeloxPlanValidator.h
@@ -61,6 +61,9 @@ class SubstraitToVeloxPlanValidator {
/// Used to validate whether the computing of this WindowGroupLimit is supported.
bool validate(const ::substrait::WindowGroupLimitRel& windowGroupLimitRel);
+ /// Used to validate whether the computing of this Set is supported.
+ bool validate(const ::substrait::SetRel& setRel);
+
/// Used to validate whether the computing of this Aggregation is supported.
bool validate(const ::substrait::AggregateRel& aggRel);
@@ -103,9 +106,17 @@ class SubstraitToVeloxPlanValidator {
std::vector validateLog_;
- /// Used to get types from advanced extension and validate them.
- bool validateInputTypes(const ::substrait::extensions::AdvancedExtension& extension, std::vector& types);
+ /// Used to get types from advanced extension and validate them, then convert to a Velox type that has arbitrary
+ /// levels of nesting.
+ bool parseVeloxType(const ::substrait::extensions::AdvancedExtension& extension, TypePtr& out);
+
+ /// Flattens a Velox type with single level of nesting into a std::vector of child types.
+ bool flattenVeloxType1(const TypePtr& type, std::vector& out);
+
+ /// Flattens a Velox type with two level of nesting into a dual-nested std::vector of child types.
+ bool flattenVeloxType2(const TypePtr& type, std::vector>& out);
+ /// Validate aggregate rel.
bool validateAggRelFunctionType(const ::substrait::AggregateRel& substraitAgg);
/// Validate the round scalar function.
diff --git a/gluten-substrait/src/main/java/org/apache/gluten/substrait/rel/RelBuilder.java b/gluten-substrait/src/main/java/org/apache/gluten/substrait/rel/RelBuilder.java
index def1dca0a028..7d1931180847 100644
--- a/gluten-substrait/src/main/java/org/apache/gluten/substrait/rel/RelBuilder.java
+++ b/gluten-substrait/src/main/java/org/apache/gluten/substrait/rel/RelBuilder.java
@@ -27,6 +27,7 @@
import io.substrait.proto.CrossRel;
import io.substrait.proto.JoinRel;
+import io.substrait.proto.SetRel;
import io.substrait.proto.SortField;
import org.apache.spark.sql.catalyst.expressions.Attribute;
@@ -317,4 +318,20 @@ public static RelNode makeGenerateRel(
context.registerRelToOperator(operatorId);
return new GenerateRelNode(input, generator, childOutput, extensionNode, outer);
}
+
+ public static RelNode makeSetRel(
+ List inputs, SetRel.SetOp setOp, SubstraitContext context, Long operatorId) {
+ context.registerRelToOperator(operatorId);
+ return new SetRelNode(inputs, setOp);
+ }
+
+ public static RelNode makeSetRel(
+ List inputs,
+ SetRel.SetOp setOp,
+ AdvancedExtensionNode extensionNode,
+ SubstraitContext context,
+ Long operatorId) {
+ context.registerRelToOperator(operatorId);
+ return new SetRelNode(inputs, setOp, extensionNode);
+ }
}
diff --git a/gluten-substrait/src/main/java/org/apache/gluten/substrait/rel/SetRelNode.java b/gluten-substrait/src/main/java/org/apache/gluten/substrait/rel/SetRelNode.java
new file mode 100644
index 000000000000..ddcfb1701dd1
--- /dev/null
+++ b/gluten-substrait/src/main/java/org/apache/gluten/substrait/rel/SetRelNode.java
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.substrait.rel;
+
+import org.apache.gluten.substrait.extensions.AdvancedExtensionNode;
+
+import io.substrait.proto.Rel;
+import io.substrait.proto.RelCommon;
+import io.substrait.proto.SetRel;
+
+import java.io.Serializable;
+import java.util.List;
+
+public class SetRelNode implements RelNode, Serializable {
+ private final List inputs;
+ private final SetRel.SetOp setOp;
+ private final AdvancedExtensionNode extensionNode;
+
+ public SetRelNode(List inputs, SetRel.SetOp setOp, AdvancedExtensionNode extensionNode) {
+ this.inputs = inputs;
+ this.setOp = setOp;
+ this.extensionNode = extensionNode;
+ }
+
+ public SetRelNode(List inputs, SetRel.SetOp setOp) {
+ this(inputs, setOp, null);
+ }
+
+ @Override
+ public Rel toProtobuf() {
+ final RelCommon.Builder relCommonBuilder = RelCommon.newBuilder();
+ relCommonBuilder.setDirect(RelCommon.Direct.newBuilder());
+ final SetRel.Builder setBuilder = SetRel.newBuilder();
+ setBuilder.setCommon(relCommonBuilder.build());
+ if (inputs != null) {
+ for (RelNode input : inputs) {
+ setBuilder.addInputs(input.toProtobuf());
+ }
+ }
+ setBuilder.setOp(setOp);
+ if (extensionNode != null) {
+ setBuilder.setAdvancedExtension(extensionNode.toProtobuf());
+ }
+ final Rel.Builder builder = Rel.newBuilder();
+ builder.setSet(setBuilder.build());
+ return builder.build();
+ }
+}
diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/MetricsApi.scala b/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/MetricsApi.scala
index c67d4b5f8876..453cfab4e487 100644
--- a/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/MetricsApi.scala
+++ b/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/MetricsApi.scala
@@ -126,6 +126,10 @@ trait MetricsApi extends Serializable {
def genSampleTransformerMetricsUpdater(metrics: Map[String, SQLMetric]): MetricsUpdater
+ def genUnionTransformerMetrics(sparkContext: SparkContext): Map[String, SQLMetric]
+
+ def genUnionTransformerMetricsUpdater(metrics: Map[String, SQLMetric]): MetricsUpdater
+
def genColumnarInMemoryTableMetrics(sparkContext: SparkContext): Map[String, SQLMetric] =
Map("numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
}
diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/execution/BasicPhysicalOperatorTransformer.scala b/gluten-substrait/src/main/scala/org/apache/gluten/execution/BasicPhysicalOperatorTransformer.scala
index f9755605cab2..ac8e610956dc 100644
--- a/gluten-substrait/src/main/scala/org/apache/gluten/execution/BasicPhysicalOperatorTransformer.scala
+++ b/gluten-substrait/src/main/scala/org/apache/gluten/execution/BasicPhysicalOperatorTransformer.scala
@@ -261,10 +261,11 @@ abstract class ProjectExecTransformerBase(val list: Seq[NamedExpression], val in
}
}
-// An alternatives for UnionExec.
+// An alternative for UnionExec.
case class ColumnarUnionExec(children: Seq[SparkPlan]) extends ValidatablePlan {
children.foreach {
case w: WholeStageTransformer =>
+ // FIXME: Avoid such practice for plan immutability.
w.setOutputSchemaForPlan(output)
case _ =>
}
diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/execution/UnionExecTransformer.scala b/gluten-substrait/src/main/scala/org/apache/gluten/execution/UnionExecTransformer.scala
new file mode 100644
index 000000000000..d27558746a40
--- /dev/null
+++ b/gluten-substrait/src/main/scala/org/apache/gluten/execution/UnionExecTransformer.scala
@@ -0,0 +1,99 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.execution
+
+import org.apache.gluten.backendsapi.BackendsApiManager
+import org.apache.gluten.expression.ConverterUtils
+import org.apache.gluten.extension.ValidationResult
+import org.apache.gluten.metrics.MetricsUpdater
+import org.apache.gluten.substrait.`type`.TypeBuilder
+import org.apache.gluten.substrait.SubstraitContext
+import org.apache.gluten.substrait.extensions.ExtensionBuilder
+import org.apache.gluten.substrait.rel.{RelBuilder, RelNode}
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.execution.{SparkPlan, UnionExec}
+import org.apache.spark.sql.execution.metric.SQLMetric
+import org.apache.spark.sql.vectorized.ColumnarBatch
+
+import io.substrait.proto.SetRel.SetOp
+
+import scala.collection.JavaConverters._
+
+/** Transformer for UnionExec. Note: Spark's UnionExec represents a SQL UNION ALL. */
+case class UnionExecTransformer(children: Seq[SparkPlan]) extends TransformSupport {
+ private val union = UnionExec(children)
+
+ // Note: "metrics" is made transient to avoid sending driver-side metrics to tasks.
+ @transient override lazy val metrics: Map[String, SQLMetric] =
+ BackendsApiManager.getMetricsApiInstance.genUnionTransformerMetrics(sparkContext)
+
+ override def output: Seq[Attribute] = union.output
+
+ override def columnarInputRDDs: Seq[RDD[ColumnarBatch]] = children.flatMap(getColumnarInputRDDs)
+
+ override def metricsUpdater(): MetricsUpdater =
+ BackendsApiManager.getMetricsApiInstance.genUnionTransformerMetricsUpdater(metrics)
+
+ override protected def withNewChildrenInternal(newChildren: IndexedSeq[SparkPlan]): SparkPlan =
+ copy(children = newChildren)
+
+ override protected def doValidateInternal(): ValidationResult = {
+ val context = new SubstraitContext
+ val operatorId = context.nextOperatorId(this.nodeName)
+ val relNode = getRelNode(context, operatorId, children.map(_.output), null, true)
+ doNativeValidation(context, relNode)
+ }
+
+ override protected def doTransform(context: SubstraitContext): TransformContext = {
+ val childrenCtx = children.map(_.asInstanceOf[TransformSupport].transform(context))
+ val operatorId = context.nextOperatorId(this.nodeName)
+ val relNode =
+ getRelNode(context, operatorId, children.map(_.output), childrenCtx.map(_.root), false)
+ TransformContext(output, relNode)
+ }
+
+ private def getRelNode(
+ context: SubstraitContext,
+ operatorId: Long,
+ inputAttributes: Seq[Seq[Attribute]],
+ inputs: Seq[RelNode],
+ validation: Boolean): RelNode = {
+ if (validation) {
+ // Use the second level of nesting to represent N way inputs.
+ val inputTypeNodes =
+ inputAttributes.map(
+ attributes =>
+ attributes.map(attr => ConverterUtils.getTypeNode(attr.dataType, attr.nullable)).asJava)
+ val extensionNode = ExtensionBuilder.makeAdvancedExtension(
+ BackendsApiManager.getTransformerApiInstance.packPBMessage(
+ TypeBuilder
+ .makeStruct(
+ false,
+ inputTypeNodes.map(nodes => TypeBuilder.makeStruct(false, nodes)).asJava)
+ .toProtobuf))
+ return RelBuilder.makeSetRel(
+ inputs.asJava,
+ SetOp.SET_OP_UNION_ALL,
+ extensionNode,
+ context,
+ operatorId)
+ }
+ RelBuilder.makeSetRel(inputs.asJava, SetOp.SET_OP_UNION_ALL, context, operatorId)
+ }
+}
diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/UnionTransformerRule.scala b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/UnionTransformerRule.scala
new file mode 100644
index 000000000000..f0eea08018dd
--- /dev/null
+++ b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/UnionTransformerRule.scala
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.extension.columnar
+
+import org.apache.gluten.GlutenConfig
+import org.apache.gluten.execution.{ColumnarUnionExec, UnionExecTransformer}
+
+import org.apache.spark.sql.catalyst.plans.physical.UnknownPartitioning
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.SparkPlan
+
+/**
+ * Replace ColumnarUnionExec with UnionExecTransformer if possible.
+ *
+ * The rule is not included in [[org.apache.gluten.extension.columnar.heuristic.HeuristicTransform]]
+ * or [[org.apache.gluten.extension.columnar.enumerated.EnumeratedTransform]] because it relies on
+ * children's output partitioning to be fully provided.
+ */
+case class UnionTransformerRule() extends Rule[SparkPlan] {
+ override def apply(plan: SparkPlan): SparkPlan = {
+ if (!GlutenConfig.getConf.enableNativeUnion) {
+ return plan
+ }
+ plan.transformUp {
+ case plan: ColumnarUnionExec =>
+ val transformer = UnionExecTransformer(plan.children)
+ if (sameNumPartitions(plan.children) && validate(transformer)) {
+ transformer
+ } else {
+ plan
+ }
+ }
+ }
+
+ private def sameNumPartitions(plans: Seq[SparkPlan]): Boolean = {
+ val partitioning = plans.map(_.outputPartitioning)
+ if (partitioning.exists(p => p.isInstanceOf[UnknownPartitioning])) {
+ return false
+ }
+ val numPartitions = plans.map(_.outputPartitioning.numPartitions)
+ numPartitions.forall(_ == numPartitions.head)
+ }
+
+ private def validate(union: UnionExecTransformer): Boolean = {
+ union.doValidate().ok()
+ }
+}
diff --git a/gluten-substrait/src/test/scala/org/apache/gluten/execution/WholeStageTransformerSuite.scala b/gluten-substrait/src/test/scala/org/apache/gluten/execution/WholeStageTransformerSuite.scala
index fd250834d078..08081fadb5f9 100644
--- a/gluten-substrait/src/test/scala/org/apache/gluten/execution/WholeStageTransformerSuite.scala
+++ b/gluten-substrait/src/test/scala/org/apache/gluten/execution/WholeStageTransformerSuite.scala
@@ -176,25 +176,39 @@ abstract class WholeStageTransformerSuite
result
}
+ protected def compareResultsAgainstVanillaSpark(
+ sql: String,
+ compareResult: Boolean = true,
+ customCheck: DataFrame => Unit,
+ noFallBack: Boolean = true,
+ cache: Boolean = false): DataFrame = {
+ compareDfResultsAgainstVanillaSpark(
+ () => spark.sql(sql),
+ compareResult,
+ customCheck,
+ noFallBack,
+ cache)
+ }
+
/**
* run a query with native engine as well as vanilla spark then compare the result set for
* correctness check
*/
- protected def compareResultsAgainstVanillaSpark(
- sqlStr: String,
+ protected def compareDfResultsAgainstVanillaSpark(
+ dataframe: () => DataFrame,
compareResult: Boolean = true,
customCheck: DataFrame => Unit,
noFallBack: Boolean = true,
cache: Boolean = false): DataFrame = {
var expected: Seq[Row] = null
withSQLConf(vanillaSparkConfs(): _*) {
- val df = spark.sql(sqlStr)
+ val df = dataframe()
expected = df.collect()
}
- // By default we will fallabck complex type scan but here we should allow
+ // By default, we will fallback complex type scan but here we should allow
// to test support of complex type
spark.conf.set("spark.gluten.sql.complexType.scan.fallback.enabled", "false");
- val df = spark.sql(sqlStr)
+ val df = dataframe()
if (cache) {
df.cache()
}
@@ -239,7 +253,12 @@ abstract class WholeStageTransformerSuite
noFallBack: Boolean = true,
cache: Boolean = false)(customCheck: DataFrame => Unit): DataFrame = {
- compareResultsAgainstVanillaSpark(sqlStr, compareResult, customCheck, noFallBack, cache)
+ compareDfResultsAgainstVanillaSpark(
+ () => spark.sql(sqlStr),
+ compareResult,
+ customCheck,
+ noFallBack,
+ cache)
}
/**
@@ -256,8 +275,8 @@ abstract class WholeStageTransformerSuite
customCheck: DataFrame => Unit,
noFallBack: Boolean = true,
compareResult: Boolean = true): Unit =
- compareResultsAgainstVanillaSpark(
- tpchSQL(queryNum, tpchQueries),
+ compareDfResultsAgainstVanillaSpark(
+ () => spark.sql(tpchSQL(queryNum, tpchQueries)),
compareResult,
customCheck,
noFallBack)
diff --git a/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala b/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala
index c4c67f49a59b..969fc072c21b 100644
--- a/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala
+++ b/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala
@@ -86,6 +86,8 @@ class GlutenConfig(conf: SQLConf) extends Logging {
def enableColumnarUnion: Boolean = conf.getConf(COLUMNAR_UNION_ENABLED)
+ def enableNativeUnion: Boolean = conf.getConf(NATIVE_UNION_ENABLED)
+
def enableColumnarExpand: Boolean = conf.getConf(COLUMNAR_EXPAND_ENABLED)
def enableColumnarBroadcastExchange: Boolean = conf.getConf(COLUMNAR_BROADCAST_EXCHANGE_ENABLED)
@@ -1012,6 +1014,13 @@ object GlutenConfig {
.booleanConf
.createWithDefault(true)
+ val NATIVE_UNION_ENABLED =
+ buildConf("spark.gluten.sql.native.union")
+ .internal()
+ .doc("Enable or disable native union where computation is completely offloaded to backend.")
+ .booleanConf
+ .createWithDefault(false)
+
val COLUMNAR_EXPAND_ENABLED =
buildConf("spark.gluten.sql.columnar.expand")
.internal()