Skip to content

Commit

Permalink
support pass format to backend
Browse files Browse the repository at this point in the history
  • Loading branch information
baibaichen committed Jul 18, 2024
1 parent 261f89a commit c475114
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 59 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,10 @@ case class CartesianProductExecTransformer(
val (inputRightRelNode, inputRightOutput) =
(rightPlanContext.root, rightPlanContext.outputAttributes)

val expressionNode = condition.map {
expr =>
ExpressionConverter
.replaceWithExpressionTransformer(expr, inputLeftOutput ++ inputRightOutput)
.doTransform(context.registeredFunction)
}
val expressionNode =
condition.map {
SubstraitUtil.toSubstraitExpression(_, inputLeftOutput ++ inputRightOutput, context)
}

val extensionNode =
JoinUtils.createExtensionNode(inputLeftOutput ++ inputRightOutput, validation = false)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,12 @@
*/
package org.apache.gluten.execution

import org.apache.gluten.backendsapi.BackendsApiManager
import org.apache.gluten.expression.{AttributeReferenceTransformer, ConverterUtils, ExpressionConverter}
import org.apache.gluten.substrait.`type`.TypeBuilder
import org.apache.gluten.expression.{AttributeReferenceTransformer, ExpressionConverter}
import org.apache.gluten.substrait.SubstraitContext
import org.apache.gluten.substrait.expression.{ExpressionBuilder, ExpressionNode}
import org.apache.gluten.substrait.extensions.{AdvancedExtensionNode, ExtensionBuilder}
import org.apache.gluten.substrait.rel.{RelBuilder, RelNode}
import org.apache.gluten.utils.SubstraitUtil

import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression}
import org.apache.spark.sql.catalyst.plans._
Expand All @@ -34,21 +33,11 @@ import io.substrait.proto.{CrossRel, JoinRel}
import scala.collection.JavaConverters._

object JoinUtils {
private def createEnhancement(output: Seq[Attribute]): com.google.protobuf.Any = {
val inputTypeNodes = output.map {
attr => ConverterUtils.getTypeNode(attr.dataType, attr.nullable)
}
// Normally the enhancement node is only used for plan validation. But here the enhancement
// is also used in execution phase. In this case an empty typeUrlPrefix need to be passed,
// so that it can be correctly parsed into json string on the cpp side.
BackendsApiManager.getTransformerApiInstance.packPBMessage(
TypeBuilder.makeStruct(false, inputTypeNodes.asJava).toProtobuf)
}

def createExtensionNode(output: Seq[Attribute], validation: Boolean): AdvancedExtensionNode = {
// Use field [enhancement] in an extension node for input type validation.
if (validation) {
ExtensionBuilder.makeAdvancedExtension(createEnhancement(output))
ExtensionBuilder.makeAdvancedExtension(SubstraitUtil.createEnhancement(output))
} else {
null
}
Expand All @@ -58,7 +47,7 @@ object JoinUtils {
!keyExprs.forall(_.isInstanceOf[AttributeReference])
}

def createPreProjectionIfNeeded(
private def createPreProjectionIfNeeded(
keyExprs: Seq[Expression],
inputNode: RelNode,
inputNodeOutput: Seq[Attribute],
Expand Down Expand Up @@ -131,17 +120,17 @@ object JoinUtils {
}
}

def createJoinExtensionNode(
private def createJoinExtensionNode(
joinParameters: Any,
output: Seq[Attribute]): AdvancedExtensionNode = {
// Use field [optimization] in an extension node
// to send some join parameters through Substrait plan.
val enhancement = createEnhancement(output)
val enhancement = SubstraitUtil.createEnhancement(output)
ExtensionBuilder.makeAdvancedExtension(joinParameters, enhancement)
}

// Return the direct join output.
protected def getDirectJoinOutput(
private def getDirectJoinOutput(
joinType: JoinType,
leftOutput: Seq[Attribute],
rightOutput: Seq[Attribute]): (Seq[Attribute], Seq[Attribute]) = {
Expand All @@ -164,7 +153,7 @@ object JoinUtils {
}
}

protected def getDirectJoinOutputSeq(
private def getDirectJoinOutputSeq(
joinType: JoinType,
leftOutput: Seq[Attribute],
rightOutput: Seq[Attribute]): Seq[Attribute] = {
Expand Down Expand Up @@ -209,8 +198,8 @@ object JoinUtils {
validation)

// Combine join keys to make a single expression.
val joinExpressionNode = (streamedKeys
.zip(buildKeys))
val joinExpressionNode = streamedKeys
.zip(buildKeys)
.map {
case ((leftKey, leftType), (rightKey, rightType)) =>
HashJoinLikeExecTransformer.makeEqualToExpression(
Expand All @@ -225,12 +214,10 @@ object JoinUtils {
HashJoinLikeExecTransformer.makeAndExpression(l, r, substraitContext.registeredFunction))

// Create post-join filter, which will be computed in hash join.
val postJoinFilter = condition.map {
expr =>
ExpressionConverter
.replaceWithExpressionTransformer(expr, streamedOutput ++ buildOutput)
.doTransform(substraitContext.registeredFunction)
}
val postJoinFilter =
condition.map {
SubstraitUtil.toSubstraitExpression(_, streamedOutput ++ buildOutput, substraitContext)
}

// Create JoinRel.
val joinRel = RelBuilder.makeJoinRel(
Expand Down Expand Up @@ -340,12 +327,14 @@ object JoinUtils {
joinParameters: Any,
validation: Boolean = false
): RelNode = {
val expressionNode = condition.map {
expr =>
ExpressionConverter
.replaceWithExpressionTransformer(expr, inputStreamedOutput ++ inputBuildOutput)
.doTransform(substraitContext.registeredFunction)
}
val expressionNode =
condition.map {
SubstraitUtil.toSubstraitExpression(
_,
inputStreamedOutput ++ inputBuildOutput,
substraitContext)
}

val extensionNode =
createJoinExtensionNode(joinParameters, inputStreamedOutput ++ inputBuildOutput)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,26 +21,28 @@ import org.apache.gluten.backendsapi.BackendsApiManager
import org.apache.gluten.expression.ConverterUtils
import org.apache.gluten.extension.ValidationResult
import org.apache.gluten.metrics.MetricsUpdater
import org.apache.gluten.substrait.`type`.{ColumnTypeNode, TypeBuilder}
import org.apache.gluten.substrait.`type`.ColumnTypeNode
import org.apache.gluten.substrait.SubstraitContext
import org.apache.gluten.substrait.extensions.ExtensionBuilder
import org.apache.gluten.substrait.rel.{RelBuilder, RelNode}
import org.apache.gluten.utils.SubstraitUtil

import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.datasources.FileFormat
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.DataSourceRegister
import org.apache.spark.sql.types.MetadataBuilder

import com.google.protobuf.{Any, StringValue}
import org.apache.parquet.hadoop.ParquetOutputFormat

import java.util.Locale

import scala.collection.JavaConverters._
import scala.collection.convert.ImplicitConversions.`collection AsScalaIterable`

/**
Expand All @@ -56,7 +58,7 @@ case class WriteFilesExecTransformer(
staticPartitions: TablePartitionSpec)
extends UnaryTransformSupport {
// Note: "metrics" is made transient to avoid sending driver-side metrics to tasks.
@transient override lazy val metrics =
@transient override lazy val metrics: Map[String, SQLMetric] =
BackendsApiManager.getMetricsApiInstance.genWriteFilesTransformerMetrics(sparkContext)

override def metricsUpdater(): MetricsUpdater =
Expand All @@ -66,27 +68,25 @@ case class WriteFilesExecTransformer(

private val caseInsensitiveOptions = CaseInsensitiveMap(options)

def genWriteParameters(): Any = {
private def genWriteParameters(): Any = {
val fileFormatStr = fileFormat match {
case register: DataSourceRegister =>
register.shortName
case _ => "UnknownFileFormat"
}
val compressionCodec =
WriteFilesExecTransformer.getCompressionCodec(caseInsensitiveOptions).capitalize
val writeParametersStr = new StringBuffer("WriteParameters:")
writeParametersStr.append("is").append(compressionCodec).append("=1").append("\n")
writeParametersStr.append("is").append(compressionCodec).append("=1")
writeParametersStr.append(";format=").append(fileFormatStr).append("\n")

val message = StringValue
.newBuilder()
.setValue(writeParametersStr.toString)
.build()
BackendsApiManager.getTransformerApiInstance.packPBMessage(message)
}

def createEnhancement(output: Seq[Attribute]): com.google.protobuf.Any = {
val inputTypeNodes = output.map {
attr => ConverterUtils.getTypeNode(attr.dataType, attr.nullable)
}

BackendsApiManager.getTransformerApiInstance.packPBMessage(
TypeBuilder.makeStruct(false, inputTypeNodes.asJava).toProtobuf)
}

def getRelNode(
context: SubstraitContext,
originalInputAttributes: Seq[Attribute],
Expand Down Expand Up @@ -118,10 +118,11 @@ case class WriteFilesExecTransformer(
val extensionNode = if (!validation) {
ExtensionBuilder.makeAdvancedExtension(
genWriteParameters(),
createEnhancement(originalInputAttributes))
SubstraitUtil.createEnhancement(originalInputAttributes))
} else {
// Use an extension node to send the input types through Substrait plan for validation.
ExtensionBuilder.makeAdvancedExtension(createEnhancement(originalInputAttributes))
ExtensionBuilder.makeAdvancedExtension(
SubstraitUtil.createEnhancement(originalInputAttributes))
}
RelBuilder.makeWriteRel(
input,
Expand All @@ -133,7 +134,7 @@ case class WriteFilesExecTransformer(
operatorId)
}

private def getFinalChildOutput(): Seq[Attribute] = {
private def getFinalChildOutput: Seq[Attribute] = {
val metadataExclusionList = conf
.getConf(GlutenConfig.NATIVE_WRITE_FILES_COLUMN_METADATA_EXCLUSION_LIST)
.split(",")
Expand All @@ -143,7 +144,7 @@ case class WriteFilesExecTransformer(
}

override protected def doValidateInternal(): ValidationResult = {
val finalChildOutput = getFinalChildOutput()
val finalChildOutput = getFinalChildOutput
val validationResult =
BackendsApiManager.getSettings.supportWriteFilesExec(
fileFormat,
Expand All @@ -165,7 +166,7 @@ case class WriteFilesExecTransformer(
val childCtx = child.asInstanceOf[TransformSupport].transform(context)
val operatorId = context.nextOperatorId(this.nodeName)
val currRel =
getRelNode(context, getFinalChildOutput(), operatorId, childCtx.root, validation = false)
getRelNode(context, getFinalChildOutput, operatorId, childCtx.root, validation = false)
assert(currRel != null, "Write Rel should be valid")
TransformContext(childCtx.outputAttributes, output, currRel)
}
Expand Down Expand Up @@ -196,7 +197,7 @@ object WriteFilesExecTransformer {
"__file_source_generated_metadata_col"
)

def removeMetadata(attr: Attribute, metadataExclusionList: Seq[String]): Attribute = {
private def removeMetadata(attr: Attribute, metadataExclusionList: Seq[String]): Attribute = {
val metadataKeys = INTERNAL_METADATA_KEYS ++ metadataExclusionList
attr.withMetadata {
var builder = new MetadataBuilder().withMetadata(attr.metadata)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,19 @@
*/
package org.apache.gluten.utils

import org.apache.gluten.backendsapi.BackendsApiManager
import org.apache.gluten.expression.{ConverterUtils, ExpressionConverter}
import org.apache.gluten.substrait.`type`.TypeBuilder
import org.apache.gluten.substrait.SubstraitContext
import org.apache.gluten.substrait.expression.ExpressionNode

import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
import org.apache.spark.sql.catalyst.plans.{FullOuter, InnerLike, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter}

import io.substrait.proto.{CrossRel, JoinRel}

import scala.collection.JavaConverters._

object SubstraitUtil {
def toSubstrait(sparkJoin: JoinType): JoinRel.JoinType = sparkJoin match {
case _: InnerLike =>
Expand Down Expand Up @@ -55,4 +64,24 @@ object SubstraitUtil {
case _ =>
CrossRel.JoinType.UNRECOGNIZED
}

/**
 * Builds the "enhancement" payload that describes the types of `output`.
 *
 * Each attribute is converted to a type node from its data type and nullability, the nodes are
 * wrapped into a single struct type, and the struct's protobuf form is packed through the
 * backend's transformer API.
 *
 * @param output
 *   attributes whose types should be described, in order
 * @return
 *   the packed protobuf message produced by the backend's `packPBMessage`
 */
def createEnhancement(output: Seq[Attribute]): com.google.protobuf.Any = {
  // One type node per attribute, preserving the attribute order.
  val fieldTypes = output.map(field => ConverterUtils.getTypeNode(field.dataType, field.nullable))
  // Normally the enhancement node is only used for plan validation, but here it is also used
  // in the execution phase. In this case an empty typeUrlPrefix needs to be passed, so that it
  // can be correctly parsed into a json string on the cpp side.
  val transformerApi = BackendsApiManager.getTransformerApiInstance
  // NOTE(review): the leading `false` is presumably the struct-level nullability flag —
  // confirm against TypeBuilder.makeStruct.
  val structType = TypeBuilder.makeStruct(false, fieldTypes.asJava)
  transformerApi.packPBMessage(structType.toProtobuf)
}

/**
 * Converts a Catalyst expression into a Substrait [[ExpressionNode]].
 *
 * The expression is first turned into an expression transformer via
 * `ExpressionConverter.replaceWithExpressionTransformer`, using `attributeSeq` as the input
 * attributes, and that transformer is then run with the functions registered in `context`.
 *
 * @param expr
 *   the expression to convert
 * @param attributeSeq
 *   the attributes passed to the converter alongside `expr`
 * @param context
 *   the Substrait context whose `registeredFunction` is handed to `doTransform`
 * @return
 *   the expression node produced by the transformer
 */
def toSubstraitExpression(
    expr: Expression,
    attributeSeq: Seq[Attribute],
    context: SubstraitContext): ExpressionNode = {
  val transformer = ExpressionConverter.replaceWithExpressionTransformer(expr, attributeSeq)
  transformer.doTransform(context.registeredFunction)
}
}

0 comments on commit c475114

Please sign in to comment.