From 0739f06fe24ef4fbb309dbae913fd81557ab09a5 Mon Sep 17 00:00:00 2001 From: Hongze Zhang Date: Fri, 23 Aug 2024 15:02:00 +0800 Subject: [PATCH 1/5] fixup --- .../clickhouse/CHListenerApi.scala | 12 +++-- .../clickhouse/CHSparkPlanExecApi.scala | 7 +-- .../CHHashAggregateExecTransformer.scala | 2 +- .../gluten/expression/CHExpressions.scala | 46 +++++++++++++++++++ .../extension/ExpressionExtensionTrait.scala | 13 ++++-- .../spark/sql/utils/ExpressionUtil.scala | 1 - .../org/apache/gluten/GlutenPlugin.scala | 12 +---- .../AggregateFunctionsBuilder.scala | 34 ++++---------- .../expression/ExpressionMappings.scala | 13 +----- .../org/apache/gluten/GlutenConfig.scala | 2 + 10 files changed, 85 insertions(+), 57 deletions(-) create mode 100644 backends-clickhouse/src/main/scala/org/apache/gluten/expression/CHExpressions.scala rename {gluten-core => backends-clickhouse}/src/main/scala/org/apache/gluten/extension/ExpressionExtensionTrait.scala (86%) rename {gluten-core => backends-clickhouse}/src/main/scala/org/apache/spark/sql/utils/ExpressionUtil.scala (99%) diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHListenerApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHListenerApi.scala index 69797feb65fb..063803416aac 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHListenerApi.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHListenerApi.scala @@ -20,9 +20,8 @@ import org.apache.gluten.GlutenConfig import org.apache.gluten.backendsapi.ListenerApi import org.apache.gluten.execution.CHBroadcastBuildSideCache import org.apache.gluten.execution.datasource.{GlutenOrcWriterInjects, GlutenParquetWriterInjects, GlutenRowSplitter} -import org.apache.gluten.expression.UDFMappings +import org.apache.gluten.expression.{ExpressionMappings, UDFMappings} import org.apache.gluten.vectorized.{CHNativeExpressionEvaluator, JniLibLoader} - import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.api.plugin.PluginContext import org.apache.spark.internal.Logging @@ -31,8 +30,8 @@ import org.apache.spark.network.util.JavaUtils import org.apache.spark.rpc.{GlutenDriverEndpoint, GlutenExecutorEndpoint} import org.apache.spark.sql.execution.datasources.v1._ import org.apache.spark.util.SparkDirectoryUtil - import org.apache.commons.lang3.StringUtils +import org.apache.spark.sql.utils.ExpressionUtil import java.util.TimeZone @@ -42,6 +41,13 @@ class CHListenerApi extends ListenerApi with Logging { GlutenDriverEndpoint.glutenDriverEndpointRef = (new GlutenDriverEndpoint).self CHGlutenSQLAppStatusListener.registerListener(sc) initialize(pc.conf, isDriver = true) + + val expressionExtensionTransformer = ExpressionUtil.extendedExpressionTransformer( + pc.conf.get(GlutenConfig.GLUTEN_EXTENDED_EXPRESSION_TRAN_CONF, "") + ) + if (expressionExtensionTransformer != null) { + ExpressionMappings.expressionExtensionTransformer = expressionExtensionTransformer + } } override def onDriverShutdown(): Unit = shutdown() diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala index 6761269651c1..72994e9cc5e5 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala @@ -18,10 +18,10 @@ package org.apache.gluten.backendsapi.clickhouse import org.apache.gluten.GlutenConfig import org.apache.gluten.backendsapi.{BackendsApiManager, SparkPlanExecApi} -import org.apache.gluten.exception.GlutenException -import org.apache.gluten.exception.GlutenNotSupportException +import org.apache.gluten.exception.{GlutenException, GlutenNotSupportException} import org.apache.gluten.execution._ import org.apache.gluten.expression._ +import org.apache.gluten.extension.ExpressionExtensionTrait import org.apache.gluten.extension.columnar.AddFallbackTagRule import org.apache.gluten.extension.columnar.MiscColumnarRules.TransformPreOverrides import org.apache.gluten.extension.columnar.transition.Convention @@ -556,6 +556,7 @@ class CHSparkPlanExecApi extends SparkPlanExecApi with Logging { Sig[CollectList](ExpressionNames.COLLECT_LIST), Sig[CollectSet](ExpressionNames.COLLECT_SET) ) ++ + ExpressionExtensionTrait.expressionExtensionTransformer.expressionSigList ++ SparkShimLoader.getSparkShims.bloomFilterExpressionMappings() } @@ -698,7 +699,7 @@ class CHSparkPlanExecApi extends SparkPlanExecApi with Logging { .doTransform(args))) val windowFunctionNode = ExpressionBuilder.makeWindowFunction( - AggregateFunctionsBuilder.create(args, aggExpression.aggregateFunction).toInt, + CHExpressions.createAggregateFunction(args, aggExpression.aggregateFunction).toInt, childrenNodeList, columnName, ConverterUtils.getTypeNode(aggExpression.dataType, aggExpression.nullable), diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashAggregateExecTransformer.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashAggregateExecTransformer.scala index 6c1fee39c423..06f9039abfa0 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashAggregateExecTransformer.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashAggregateExecTransformer.scala @@ -249,7 +249,7 @@ case class CHHashAggregateExecTransformer( childrenNodeList.add(node) } val aggFunctionNode = ExpressionBuilder.makeAggregateFunction( - AggregateFunctionsBuilder.create(args, aggregateFunc), + CHExpressions.createAggregateFunction(args, aggregateFunc), childrenNodeList, modeToKeyWord(aggExpr.mode), ConverterUtils.getTypeNode(aggregateFunc.dataType, aggregateFunc.nullable) diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/expression/CHExpressions.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/expression/CHExpressions.scala new file mode 100644 index 000000000000..983f21dd3a96 --- /dev/null +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/expression/CHExpressions.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.gluten.expression + +import org.apache.gluten.expression.ConverterUtils.FunctionConfig +import org.apache.gluten.extension.ExpressionExtensionTrait +import org.apache.gluten.substrait.expression.ExpressionBuilder + +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction + +// Static helper object for handling expressions that are specifically used in CH backend. +object CHExpressions { + // Since https://github.com/apache/incubator-gluten/pull/1937. + def createAggregateFunction(args: java.lang.Object, aggregateFunc: AggregateFunction): Long = { + val functionMap = args.asInstanceOf[java.util.HashMap[String, java.lang.Long]] + if ( + ExpressionExtensionTrait.expressionExtensionTransformer.extensionExpressionsMapping.contains( + aggregateFunc.getClass) + ) { + val (substraitAggFuncName, inputTypes) = + ExpressionExtensionTrait.expressionExtensionTransformer.buildCustomAggregateFunction( + aggregateFunc) + assert(substraitAggFuncName.isDefined) + return ExpressionBuilder.newScalarFunction( + functionMap, + ConverterUtils.makeFuncName(substraitAggFuncName.get, inputTypes, FunctionConfig.REQ)) + } + + AggregateFunctionsBuilder.create(args, aggregateFunc) + } +} diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/ExpressionExtensionTrait.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/ExpressionExtensionTrait.scala similarity index 86% rename from gluten-core/src/main/scala/org/apache/gluten/extension/ExpressionExtensionTrait.scala rename to backends-clickhouse/src/main/scala/org/apache/gluten/extension/ExpressionExtensionTrait.scala index 89bcb70641bd..9e0b0c05150a 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/ExpressionExtensionTrait.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/ExpressionExtensionTrait.scala @@ -63,8 +63,15 @@ trait ExpressionExtensionTrait { } } -case class DefaultExpressionExtensionTransformer() extends ExpressionExtensionTrait with Logging { +object ExpressionExtensionTrait { + var expressionExtensionTransformer: ExpressionExtensionTrait = + DefaultExpressionExtensionTransformer() - /** Generate the extension expressions list, format: Sig[XXXExpression]("XXXExpressionName") */ - override def expressionSigList: Seq[Sig] = Seq.empty[Sig] + private case class DefaultExpressionExtensionTransformer() + extends ExpressionExtensionTrait + with Logging { + + /** Generate the extension expressions list, format: Sig[XXXExpression]("XXXExpressionName") */ + override def expressionSigList: Seq[Sig] = Seq.empty[Sig] + } } diff --git a/gluten-core/src/main/scala/org/apache/spark/sql/utils/ExpressionUtil.scala b/backends-clickhouse/src/main/scala/org/apache/spark/sql/utils/ExpressionUtil.scala similarity index 99% rename from gluten-core/src/main/scala/org/apache/spark/sql/utils/ExpressionUtil.scala rename to backends-clickhouse/src/main/scala/org/apache/spark/sql/utils/ExpressionUtil.scala index b5c45e090f38..00d0f5ee862b 100644 --- a/gluten-core/src/main/scala/org/apache/spark/sql/utils/ExpressionUtil.scala +++ b/backends-clickhouse/src/main/scala/org/apache/spark/sql/utils/ExpressionUtil.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.utils import org.apache.gluten.extension.{DefaultExpressionExtensionTransformer, ExpressionExtensionTrait} - import org.apache.spark.internal.Logging import org.apache.spark.util.Utils diff --git a/gluten-core/src/main/scala/org/apache/gluten/GlutenPlugin.scala b/gluten-core/src/main/scala/org/apache/gluten/GlutenPlugin.scala index 6e3484dfa969..7a640651af3a 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/GlutenPlugin.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/GlutenPlugin.scala @@ -20,7 +20,6 @@ import org.apache.gluten.GlutenConfig.GLUTEN_DEFAULT_SESSION_TIMEZONE_KEY import org.apache.gluten.backendsapi.BackendsApiManager import org.apache.gluten.events.GlutenBuildInfoEvent import org.apache.gluten.exception.GlutenException -import org.apache.gluten.expression.ExpressionMappings import org.apache.gluten.extension.GlutenSessionExtensions.{GLUTEN_SESSION_EXTENSION_NAME, SPARK_SESSION_EXTS_KEY} import org.apache.gluten.test.TestStats import org.apache.gluten.utils.TaskListener @@ -32,7 +31,6 @@ import org.apache.spark.listener.GlutenListenerFactory import org.apache.spark.network.util.JavaUtils import org.apache.spark.sql.execution.ui.GlutenEventUtils import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.utils.ExpressionUtil import org.apache.spark.util.{SparkResourceUtil, TaskResources} import java.util @@ -73,14 +71,6 @@ private[gluten] class GlutenDriverPlugin extends DriverPlugin with Logging { BackendsApiManager.getListenerApiInstance.onDriverStart(sc, pluginContext) GlutenListenerFactory.addToSparkListenerBus(sc) - val expressionExtensionTransformer = ExpressionUtil.extendedExpressionTransformer( - conf.get(GlutenConfig.GLUTEN_EXTENDED_EXPRESSION_TRAN_CONF, "") - ) - - if (expressionExtensionTransformer != null) { - ExpressionMappings.expressionExtensionTransformer = expressionExtensionTransformer - } - Collections.emptyMap() } @@ -265,7 +255,7 @@ private[gluten] class GlutenDriverPlugin extends DriverPlugin with Logging { } private[gluten] class GlutenExecutorPlugin extends ExecutorPlugin { - private val taskListeners: Seq[TaskListener] = Array(TaskResources) + private val taskListeners: Seq[TaskListener] = Seq(TaskResources) /** Initialize the executor plugin. */ override def init(ctx: PluginContext, extraConf: util.Map[String, String]): Unit = { diff --git a/gluten-core/src/main/scala/org/apache/gluten/expression/AggregateFunctionsBuilder.scala b/gluten-core/src/main/scala/org/apache/gluten/expression/AggregateFunctionsBuilder.scala index 6ac2c67eb086..bd73b7b7aa54 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/expression/AggregateFunctionsBuilder.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/expression/AggregateFunctionsBuilder.scala @@ -29,32 +29,18 @@ object AggregateFunctionsBuilder { val functionMap = args.asInstanceOf[java.util.HashMap[String, java.lang.Long]] // First handle the custom aggregate functions - val (substraitAggFuncName, inputTypes) = - if ( - ExpressionMappings.expressionExtensionTransformer.extensionExpressionsMapping.contains( - aggregateFunc.getClass) - ) { - val (substraitAggFuncName, inputTypes) = - ExpressionMappings.expressionExtensionTransformer.buildCustomAggregateFunction( - aggregateFunc) - assert(substraitAggFuncName.isDefined) - (substraitAggFuncName.get, inputTypes) - } else { - val substraitAggFuncName = getSubstraitFunctionName(aggregateFunc) + val substraitAggFuncName = getSubstraitFunctionName(aggregateFunc) - // Check whether each backend supports this aggregate function. - if ( - !BackendsApiManager.getValidatorApiInstance.doExprValidate( - substraitAggFuncName, - aggregateFunc) - ) { - throw new GlutenNotSupportException( - s"Aggregate function not supported for $aggregateFunc.") - } + // Check whether each backend supports this aggregate function. + if ( + !BackendsApiManager.getValidatorApiInstance.doExprValidate( + substraitAggFuncName, + aggregateFunc) + ) { + throw new GlutenNotSupportException(s"Aggregate function not supported for $aggregateFunc.") + } - val inputTypes: Seq[DataType] = aggregateFunc.children.map(child => child.dataType) - (substraitAggFuncName, inputTypes) - } + val inputTypes: Seq[DataType] = aggregateFunc.children.map(child => child.dataType) ExpressionBuilder.newScalarFunction( functionMap, diff --git a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala index f2bb4a90621a..ea7127a40e6f 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala @@ -19,7 +19,6 @@ package org.apache.gluten.expression import org.apache.gluten.GlutenConfig import org.apache.gluten.backendsapi.BackendsApiManager import org.apache.gluten.expression.ExpressionNames._ -import org.apache.gluten.extension.{DefaultExpressionExtensionTransformer, ExpressionExtensionTrait} import org.apache.gluten.sql.shims.SparkShimLoader import org.apache.spark.sql.catalyst.expressions._ @@ -338,13 +337,8 @@ object ExpressionMappings { def expressionsMap: Map[Class[_], String] = { val blacklist = GlutenConfig.getConf.expressionBlacklist - val supportedExprs = defaultExpressionsMap ++ - expressionExtensionTransformer.extensionExpressionsMapping - if (blacklist.isEmpty) { - supportedExprs - } else { - supportedExprs.filterNot(kv => blacklist.contains(kv._2)) - } + val filtered = defaultExpressionsMap.filterNot(kv => blacklist.contains(kv._2)) + filtered } private lazy val defaultExpressionsMap: Map[Class[_], String] = { @@ -353,7 +347,4 @@ object ExpressionMappings { .map(s => (s.expClass, s.name)) .toMap[Class[_], String] } - - var expressionExtensionTransformer: ExpressionExtensionTrait = - DefaultExpressionExtensionTransformer() } diff --git a/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala b/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala index 9e5161fea472..10cf1958bf31 100644 --- a/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala +++ b/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala @@ -607,6 +607,7 @@ object GlutenConfig { val GLUTEN_SUPPORTED_PYTHON_UDFS = "spark.gluten.supported.python.udfs" val GLUTEN_SUPPORTED_SCALA_UDFS = "spark.gluten.supported.scala.udfs" + // FIXME: This only works with CH backend. val GLUTEN_EXTENDED_EXPRESSION_TRAN_CONF = "spark.gluten.sql.columnar.extended.expressions.transformer" @@ -1672,6 +1673,7 @@ object GlutenConfig { .stringConf .createWithDefaultString("") + // FIXME: This only works with CH backend. val EXTENDED_EXPRESSION_TRAN_CONF = buildConf(GLUTEN_EXTENDED_EXPRESSION_TRAN_CONF) .doc("A class for the extended expressions transformer.") From e24165af58b589732423b2afea6afa7a117708bb Mon Sep 17 00:00:00 2001 From: Hongze Zhang Date: Fri, 23 Aug 2024 15:42:28 +0800 Subject: [PATCH 2/5] fixup --- .../clickhouse/CHListenerApi.scala | 9 +- .../clickhouse/CHSparkPlanExecApi.scala | 15 ++ .../CHHashAggregateExecTransformer.scala | 11 +- .../gluten/expression/CHExpressions.scala | 1 - .../extension/ExpressionExtensionTrait.scala | 4 +- .../spark/sql/utils/ExpressionUtil.scala | 4 +- .../gluten/backendsapi/SparkPlanExecApi.scala | 9 +- .../expression/ExpressionConverter.scala | 254 ++++++++---------- .../utils/velox/VeloxTestSettings.scala | 3 +- 9 files changed, 152 insertions(+), 158 deletions(-) diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHListenerApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHListenerApi.scala index 063803416aac..60dc3dad0b87 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHListenerApi.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHListenerApi.scala @@ -20,8 +20,10 @@ import org.apache.gluten.GlutenConfig import org.apache.gluten.backendsapi.ListenerApi import org.apache.gluten.execution.CHBroadcastBuildSideCache import org.apache.gluten.execution.datasource.{GlutenOrcWriterInjects, GlutenParquetWriterInjects, GlutenRowSplitter} -import org.apache.gluten.expression.{ExpressionMappings, UDFMappings} +import org.apache.gluten.expression.UDFMappings +import org.apache.gluten.extension.ExpressionExtensionTrait import org.apache.gluten.vectorized.{CHNativeExpressionEvaluator, JniLibLoader} + import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.api.plugin.PluginContext import org.apache.spark.internal.Logging @@ -29,9 +31,10 @@ import org.apache.spark.listener.CHGlutenSQLAppStatusListener import org.apache.spark.network.util.JavaUtils import org.apache.spark.rpc.{GlutenDriverEndpoint, GlutenExecutorEndpoint} import org.apache.spark.sql.execution.datasources.v1._ +import org.apache.spark.sql.utils.ExpressionUtil import org.apache.spark.util.SparkDirectoryUtil + import org.apache.commons.lang3.StringUtils -import org.apache.spark.sql.utils.ExpressionUtil import java.util.TimeZone @@ -46,7 +49,7 @@ class CHListenerApi extends ListenerApi with Logging { pc.conf.get(GlutenConfig.GLUTEN_EXTENDED_EXPRESSION_TRAN_CONF, "") ) if (expressionExtensionTransformer != null) { - ExpressionMappings.expressionExtensionTransformer = expressionExtensionTransformer + ExpressionExtensionTrait.expressionExtensionTransformer = expressionExtensionTransformer } } diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala index 72994e9cc5e5..8431a3fe96f8 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala @@ -560,6 +560,21 @@ class CHSparkPlanExecApi extends SparkPlanExecApi with Logging { SparkShimLoader.getSparkShims.bloomFilterExpressionMappings() } + /** Define backend-specific expression converter. */ + override def extraExpressionConverter( + substraitExprName: String, + expr: Expression, + attributeSeq: Seq[Attribute]): Option[ExpressionTransformer] = expr match { + case e + if ExpressionExtensionTrait.expressionExtensionTransformer.extensionExpressionsMapping + .contains(e.getClass) => + // Use extended expression transformer to replace custom expression first + Some( + ExpressionExtensionTrait.expressionExtensionTransformer + .replaceWithExtensionExpressionTransformer(substraitExprName, e, attributeSeq)) + case _ => None + } + override def genStringTranslateTransformer( substraitExprName: String, srcExpr: ExpressionTransformer, diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashAggregateExecTransformer.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashAggregateExecTransformer.scala index 06f9039abfa0..d641c05cd62e 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashAggregateExecTransformer.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashAggregateExecTransformer.scala @@ -20,6 +20,7 @@ import org.apache.gluten.backendsapi.BackendsApiManager import org.apache.gluten.exception.GlutenNotSupportException import org.apache.gluten.execution.CHHashAggregateExecTransformer.getAggregateResultAttributes import org.apache.gluten.expression._ +import org.apache.gluten.extension.ExpressionExtensionTrait import org.apache.gluten.substrait.`type`.{TypeBuilder, TypeNode} import org.apache.gluten.substrait.{AggregationParams, SubstraitContext} import org.apache.gluten.substrait.expression.{AggregateFunctionNode, ExpressionBuilder, ExpressionNode} @@ -286,10 +287,10 @@ case class CHHashAggregateExecTransformer( val aggregateFunc = aggExpr.aggregateFunction var aggFunctionName = if ( - ExpressionMappings.expressionExtensionTransformer.extensionExpressionsMapping.contains( - aggregateFunc.getClass) + ExpressionExtensionTrait.expressionExtensionTransformer.extensionExpressionsMapping + .contains(aggregateFunc.getClass) ) { - ExpressionMappings.expressionExtensionTransformer + ExpressionExtensionTrait.expressionExtensionTransformer .buildCustomAggregateFunction(aggregateFunc) ._1 .get @@ -437,10 +438,10 @@ case class CHHashAggregateExecPullOutHelper( val aggregateFunc = exp.aggregateFunction // First handle the custom aggregate functions if ( - ExpressionMappings.expressionExtensionTransformer.extensionExpressionsMapping.contains( + ExpressionExtensionTrait.expressionExtensionTransformer.extensionExpressionsMapping.contains( aggregateFunc.getClass) ) { - ExpressionMappings.expressionExtensionTransformer + ExpressionExtensionTrait.expressionExtensionTransformer .getAttrsIndexForExtensionAggregateExpr( aggregateFunc, exp.mode, diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/expression/CHExpressions.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/expression/CHExpressions.scala index 983f21dd3a96..af1ac52b1e40 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/expression/CHExpressions.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/expression/CHExpressions.scala @@ -14,7 +14,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package org.apache.gluten.expression import org.apache.gluten.expression.ConverterUtils.FunctionConfig diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/ExpressionExtensionTrait.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/ExpressionExtensionTrait.scala index 9e0b0c05150a..c64f26869eb6 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/ExpressionExtensionTrait.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/ExpressionExtensionTrait.scala @@ -67,9 +67,7 @@ object ExpressionExtensionTrait { var expressionExtensionTransformer: ExpressionExtensionTrait = DefaultExpressionExtensionTransformer() - private case class DefaultExpressionExtensionTransformer() - extends ExpressionExtensionTrait - with Logging { + case class DefaultExpressionExtensionTransformer() extends ExpressionExtensionTrait with Logging { /** Generate the extension expressions list, format: Sig[XXXExpression]("XXXExpressionName") */ override def expressionSigList: Seq[Sig] = Seq.empty[Sig] diff --git a/backends-clickhouse/src/main/scala/org/apache/spark/sql/utils/ExpressionUtil.scala b/backends-clickhouse/src/main/scala/org/apache/spark/sql/utils/ExpressionUtil.scala index 00d0f5ee862b..852b34a099f2 100644 --- a/backends-clickhouse/src/main/scala/org/apache/spark/sql/utils/ExpressionUtil.scala +++ b/backends-clickhouse/src/main/scala/org/apache/spark/sql/utils/ExpressionUtil.scala @@ -16,7 +16,9 @@ */ package org.apache.spark.sql.utils -import org.apache.gluten.extension.{DefaultExpressionExtensionTransformer, ExpressionExtensionTrait} +import org.apache.gluten.extension.ExpressionExtensionTrait +import org.apache.gluten.extension.ExpressionExtensionTrait.DefaultExpressionExtensionTransformer + import org.apache.spark.internal.Logging import org.apache.spark.util.Utils diff --git a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala index 0227ed5da127..fc0e063be6ab 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala @@ -448,9 +448,16 @@ trait SparkPlanExecApi { GenericExpressionTransformer(substraitExprName, exprs, original) } - /** Define backend specfic expression mappings. */ + /** Define backend-specific expression mappings. */ def extraExpressionMappings: Seq[Sig] = Seq.empty + /** Define backend-specific expression converter. */ + def extraExpressionConverter( + substraitExprName: String, + expr: Expression, + attributeSeq: Seq[Attribute]): Option[ExpressionTransformer] = + None + /** * Define whether the join operator is fallback because of the join operator is not supported by * backend diff --git a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala index 8bca5dbf8605..f4f91aed048c 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala @@ -43,16 +43,14 @@ object ExpressionConverter extends SQLConfHelper with Logging { exprs: Seq[Expression], attributeSeq: Seq[Attribute]): Seq[ExpressionTransformer] = { val expressionsMap = ExpressionMappings.expressionsMap - exprs.map { - expr => replaceWithExpressionTransformerInternal(expr, attributeSeq, expressionsMap) - } + exprs.map(expr => replaceWithExpressionTransformer0(expr, attributeSeq, expressionsMap)) } def replaceWithExpressionTransformer( expr: Expression, attributeSeq: Seq[Attribute]): ExpressionTransformer = { val expressionsMap = ExpressionMappings.expressionsMap - replaceWithExpressionTransformerInternal(expr, attributeSeq, expressionsMap) + replaceWithExpressionTransformer0(expr, attributeSeq, expressionsMap) } private def replacePythonUDFWithExpressionTransformer( @@ -64,8 +62,7 @@ object ExpressionConverter extends SQLConfHelper with Logging { case Some(name) => GenericExpressionTransformer( name, - udf.children.map( - replaceWithExpressionTransformerInternal(_, attributeSeq, expressionsMap)), + udf.children.map(replaceWithExpressionTransformer0(_, attributeSeq, expressionsMap)), udf) case _ => throw new GlutenNotSupportException(s"Not supported python udf: $udf.") @@ -84,8 +81,7 @@ object ExpressionConverter extends SQLConfHelper with Logging { case Some(name) => GenericExpressionTransformer( name, - udf.children.map( - replaceWithExpressionTransformerInternal(_, attributeSeq, expressionsMap)), + udf.children.map(replaceWithExpressionTransformer0(_, attributeSeq, expressionsMap)), udf) case _ => throw new GlutenNotSupportException(s"Not supported scala udf: $udf.") @@ -108,13 +104,13 @@ object ExpressionConverter extends SQLConfHelper with Logging { ) val leftChild = - replaceWithExpressionTransformerInternal(left, attributeSeq, expressionsMap) + replaceWithExpressionTransformer0(left, attributeSeq, expressionsMap) val rightChild = - replaceWithExpressionTransformerInternal(right, attributeSeq, expressionsMap) + replaceWithExpressionTransformer0(right, attributeSeq, expressionsMap) DecimalArithmeticExpressionTransformer(substraitName, leftChild, rightChild, resultType, b) } - private def replaceWithExpressionTransformerInternal( + private def replaceWithExpressionTransformer0( expr: Expression, attributeSeq: Seq[Attribute], expressionsMap: Map[Class[_], String]): ExpressionTransformer = { @@ -137,14 +133,12 @@ object ExpressionConverter extends SQLConfHelper with Logging { case "decode" => return GenericExpressionTransformer( ExpressionNames.URL_DECODE, - child.map( - replaceWithExpressionTransformerInternal(_, attributeSeq, expressionsMap)), + child.map(replaceWithExpressionTransformer0(_, attributeSeq, expressionsMap)), i) case "encode" => return GenericExpressionTransformer( ExpressionNames.URL_ENCODE, - child.map( - replaceWithExpressionTransformerInternal(_, attributeSeq, expressionsMap)), + child.map(replaceWithExpressionTransformer0(_, attributeSeq, expressionsMap)), i) } } @@ -152,61 +146,61 @@ object ExpressionConverter extends SQLConfHelper with Logging { } val substraitExprName: String = getAndCheckSubstraitName(expr, expressionsMap) - + val backendConverted = BackendsApiManager.getSparkPlanExecApiInstance.extraExpressionConverter( + substraitExprName, + expr, + attributeSeq) + if (backendConverted.isDefined) { + return backendConverted.get + } expr match { - case extendedExpr - if ExpressionMappings.expressionExtensionTransformer.extensionExpressionsMapping.contains( - extendedExpr.getClass) => - // Use extended expression transformer to replace custom expression first - ExpressionMappings.expressionExtensionTransformer - .replaceWithExtensionExpressionTransformer(substraitExprName, extendedExpr, attributeSeq) case c: CreateArray => val children = - c.children.map(replaceWithExpressionTransformerInternal(_, attributeSeq, expressionsMap)) + c.children.map(replaceWithExpressionTransformer0(_, attributeSeq, expressionsMap)) CreateArrayTransformer(substraitExprName, children, c) case g: GetArrayItem => BackendsApiManager.getSparkPlanExecApiInstance.genGetArrayItemTransformer( substraitExprName, - replaceWithExpressionTransformerInternal(g.left, attributeSeq, expressionsMap), - replaceWithExpressionTransformerInternal(g.right, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(g.left, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(g.right, attributeSeq, expressionsMap), g ) case c: CreateMap => val children = - c.children.map(replaceWithExpressionTransformerInternal(_, attributeSeq, expressionsMap)) + c.children.map(replaceWithExpressionTransformer0(_, attributeSeq, expressionsMap)) CreateMapTransformer(substraitExprName, children, c) case g: GetMapValue => BackendsApiManager.getSparkPlanExecApiInstance.genGetMapValueTransformer( substraitExprName, - replaceWithExpressionTransformerInternal(g.child, attributeSeq, expressionsMap), - replaceWithExpressionTransformerInternal(g.key, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(g.child, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(g.key, attributeSeq, expressionsMap), g ) case m: MapEntries => BackendsApiManager.getSparkPlanExecApiInstance.genMapEntriesTransformer( substraitExprName, - replaceWithExpressionTransformerInternal(m.child, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(m.child, attributeSeq, expressionsMap), m) case e: Explode => ExplodeTransformer( substraitExprName, - replaceWithExpressionTransformerInternal(e.child, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(e.child, attributeSeq, expressionsMap), e) case p: PosExplode => BackendsApiManager.getSparkPlanExecApiInstance.genPosExplodeTransformer( substraitExprName, - replaceWithExpressionTransformerInternal(p.child, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(p.child, attributeSeq, expressionsMap), p, attributeSeq) case i: Inline => BackendsApiManager.getSparkPlanExecApiInstance.genInlineTransformer( substraitExprName, - replaceWithExpressionTransformerInternal(i.child, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(i.child, attributeSeq, expressionsMap), i) case a: Alias => BackendsApiManager.getSparkPlanExecApiInstance.genAliasTransformer( substraitExprName, - replaceWithExpressionTransformerInternal(a.child, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(a.child, attributeSeq, expressionsMap), a) case a: AttributeReference => if (attributeSeq == null) { @@ -231,14 +225,14 @@ object ExpressionConverter extends SQLConfHelper with Logging { case d: DateDiff => BackendsApiManager.getSparkPlanExecApiInstance.genDateDiffTransformer( substraitExprName, - replaceWithExpressionTransformerInternal(d.endDate, attributeSeq, expressionsMap), - replaceWithExpressionTransformerInternal(d.startDate, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(d.endDate, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(d.startDate, attributeSeq, expressionsMap), d ) case r: Round if r.child.dataType.isInstanceOf[DecimalType] => DecimalRoundTransformer( substraitExprName, - replaceWithExpressionTransformerInternal(r.child, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(r.child, attributeSeq, expressionsMap), r) case t: ToUnixTimestamp => // The failOnError depends on the config for ANSI. ANSI is not supported currently. @@ -246,8 +240,8 @@ object ExpressionConverter extends SQLConfHelper with Logging { GenericExpressionTransformer( substraitExprName, Seq( - replaceWithExpressionTransformerInternal(t.timeExp, attributeSeq, expressionsMap), - replaceWithExpressionTransformerInternal(t.format, attributeSeq, expressionsMap) + replaceWithExpressionTransformer0(t.timeExp, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(t.format, attributeSeq, expressionsMap) ), t ) @@ -255,33 +249,33 @@ object ExpressionConverter extends SQLConfHelper with Logging { GenericExpressionTransformer( substraitExprName, Seq( - replaceWithExpressionTransformerInternal(u.timeExp, attributeSeq, expressionsMap), - replaceWithExpressionTransformerInternal(u.format, attributeSeq, expressionsMap) + replaceWithExpressionTransformer0(u.timeExp, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(u.format, attributeSeq, expressionsMap) ), ToUnixTimestamp(u.timeExp, u.format, u.timeZoneId, u.failOnError) ) case t: TruncTimestamp => BackendsApiManager.getSparkPlanExecApiInstance.genTruncTimestampTransformer( substraitExprName, - replaceWithExpressionTransformerInternal(t.format, attributeSeq, expressionsMap), - replaceWithExpressionTransformerInternal(t.timestamp, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(t.format, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(t.timestamp, attributeSeq, expressionsMap), t.timeZoneId, t ) case m: MonthsBetween => MonthsBetweenTransformer( substraitExprName, - replaceWithExpressionTransformerInternal(m.date1, attributeSeq, expressionsMap), - replaceWithExpressionTransformerInternal(m.date2, attributeSeq, expressionsMap), - replaceWithExpressionTransformerInternal(m.roundOff, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(m.date1, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(m.date2, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(m.roundOff, attributeSeq, expressionsMap), m ) case i: If => IfTransformer( substraitExprName, - replaceWithExpressionTransformerInternal(i.predicate, attributeSeq, expressionsMap), - replaceWithExpressionTransformerInternal(i.trueValue, attributeSeq, expressionsMap), - replaceWithExpressionTransformerInternal(i.falseValue, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(i.predicate, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(i.trueValue, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(i.falseValue, attributeSeq, expressionsMap), i ) case cw: CaseWhen => @@ -291,14 +285,14 @@ object ExpressionConverter extends SQLConfHelper with Logging { expr => { ( - replaceWithExpressionTransformerInternal(expr._1, attributeSeq, expressionsMap), - replaceWithExpressionTransformerInternal(expr._2, attributeSeq, expressionsMap)) + replaceWithExpressionTransformer0(expr._1, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(expr._2, attributeSeq, expressionsMap)) } }, cw.elseValue.map { expr => { - replaceWithExpressionTransformerInternal(expr, attributeSeq, expressionsMap) + replaceWithExpressionTransformer0(expr, attributeSeq, expressionsMap) } }, cw @@ -310,12 +304,12 @@ object ExpressionConverter extends SQLConfHelper with Logging { } InTransformer( substraitExprName, - replaceWithExpressionTransformerInternal(i.value, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(i.value, attributeSeq, expressionsMap), i) case i: InSet => InSetTransformer( substraitExprName, - replaceWithExpressionTransformerInternal(i.child, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(i.child, attributeSeq, expressionsMap), i) case s: ScalarSubquery => ScalarSubqueryTransformer(substraitExprName, s) @@ -325,7 +319,7 @@ object ExpressionConverter extends SQLConfHelper with Logging { BackendsApiManager.getSparkPlanExecApiInstance.genCastWithNewChild(c) CastTransformer( substraitExprName, - replaceWithExpressionTransformerInternal(newCast.child, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(newCast.child, attributeSeq, expressionsMap), newCast) case s: String2TrimExpression => val (srcStr, trimStr) = s match { @@ -334,9 +328,9 @@ object ExpressionConverter extends SQLConfHelper with Logging { case StringTrimRight(srcStr, trimStr) => (srcStr, trimStr) } val children = trimStr - .map(replaceWithExpressionTransformerInternal(_, attributeSeq, expressionsMap)) + .map(replaceWithExpressionTransformer0(_, attributeSeq, expressionsMap)) .toSeq ++ - Seq(replaceWithExpressionTransformerInternal(srcStr, attributeSeq, expressionsMap)) + Seq(replaceWithExpressionTransformer0(srcStr, attributeSeq, expressionsMap)) GenericExpressionTransformer( substraitExprName, children, @@ -346,23 +340,20 @@ object ExpressionConverter extends SQLConfHelper with Logging { BackendsApiManager.getSparkPlanExecApiInstance.genHashExpressionTransformer( substraitExprName, m.children.map( - expr => replaceWithExpressionTransformerInternal(expr, attributeSeq, expressionsMap)), + expr => replaceWithExpressionTransformer0(expr, attributeSeq, expressionsMap)), m) case getStructField: GetStructField => // Different backends may have different result. BackendsApiManager.getSparkPlanExecApiInstance.genGetStructFieldTransformer( substraitExprName, - replaceWithExpressionTransformerInternal( - getStructField.child, - attributeSeq, - expressionsMap), + replaceWithExpressionTransformer0(getStructField.child, attributeSeq, expressionsMap), getStructField.ordinal, getStructField) case getArrayStructFields: GetArrayStructFields => GenericExpressionTransformer( substraitExprName, Seq( - replaceWithExpressionTransformerInternal( + replaceWithExpressionTransformer0( getArrayStructFields.child, attributeSeq, expressionsMap), @@ -372,26 +363,26 @@ object ExpressionConverter extends SQLConfHelper with Logging { case t: StringTranslate => BackendsApiManager.getSparkPlanExecApiInstance.genStringTranslateTransformer( substraitExprName, - replaceWithExpressionTransformerInternal(t.srcExpr, attributeSeq, expressionsMap), - replaceWithExpressionTransformerInternal(t.matchingExpr, attributeSeq, expressionsMap), - replaceWithExpressionTransformerInternal(t.replaceExpr, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(t.srcExpr, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(t.matchingExpr, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(t.replaceExpr, attributeSeq, expressionsMap), t ) case r: RegExpReplace => BackendsApiManager.getSparkPlanExecApiInstance.genRegexpReplaceTransformer( substraitExprName, Seq( - replaceWithExpressionTransformerInternal(r.subject, attributeSeq, expressionsMap), - replaceWithExpressionTransformerInternal(r.regexp, attributeSeq, expressionsMap), - replaceWithExpressionTransformerInternal(r.rep, attributeSeq, expressionsMap), - replaceWithExpressionTransformerInternal(r.pos, attributeSeq, expressionsMap) + replaceWithExpressionTransformer0(r.subject, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(r.regexp, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(r.rep, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(r.pos, attributeSeq, expressionsMap) ), r ) case size: Size => // Covers Spark ArraySize which is replaced by Size(child, false). val child = - replaceWithExpressionTransformerInternal(size.child, attributeSeq, expressionsMap) + replaceWithExpressionTransformer0(size.child, attributeSeq, expressionsMap) GenericExpressionTransformer( substraitExprName, Seq(child, LiteralTransformer(size.legacySizeOfNull)), @@ -400,7 +391,7 @@ object ExpressionConverter extends SQLConfHelper with Logging { BackendsApiManager.getSparkPlanExecApiInstance.genNamedStructTransformer( substraitExprName, namedStruct.children.map( - replaceWithExpressionTransformerInternal(_, attributeSeq, expressionsMap)), + replaceWithExpressionTransformer0(_, attributeSeq, expressionsMap)), namedStruct, attributeSeq) case namedLambdaVariable: NamedLambdaVariable => @@ -413,64 +404,57 @@ object ExpressionConverter extends SQLConfHelper with Logging { case lambdaFunction: LambdaFunction => LambdaFunctionTransformer( substraitExprName, - function = replaceWithExpressionTransformerInternal( + function = replaceWithExpressionTransformer0( lambdaFunction.function, attributeSeq, expressionsMap), arguments = lambdaFunction.arguments.map( - replaceWithExpressionTransformerInternal(_, attributeSeq, expressionsMap)), + replaceWithExpressionTransformer0(_, attributeSeq, expressionsMap)), original = lambdaFunction ) case j: JsonTuple => val children = - j.children.map(replaceWithExpressionTransformerInternal(_, attributeSeq, expressionsMap)) + j.children.map(replaceWithExpressionTransformer0(_, attributeSeq, expressionsMap)) JsonTupleExpressionTransformer(substraitExprName, children, j) case l: Like => BackendsApiManager.getSparkPlanExecApiInstance.genLikeTransformer( substraitExprName, - replaceWithExpressionTransformerInternal(l.left, attributeSeq, expressionsMap), - replaceWithExpressionTransformerInternal(l.right, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(l.left, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(l.right, attributeSeq, expressionsMap), l ) case m: MakeDecimal => GenericExpressionTransformer( substraitExprName, Seq( - replaceWithExpressionTransformerInternal(m.child, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(m.child, attributeSeq, expressionsMap), LiteralTransformer(m.nullOnOverflow)), m ) case _: NormalizeNaNAndZero | _: PromotePrecision | _: TaggingExpression => ChildTransformer( substraitExprName, - replaceWithExpressionTransformerInternal( - expr.children.head, - attributeSeq, - expressionsMap), + replaceWithExpressionTransformer0(expr.children.head, attributeSeq, expressionsMap), expr ) case _: GetDateField | _: GetTimeField => ExtractDateTransformer( substraitExprName, - replaceWithExpressionTransformerInternal( - expr.children.head, - attributeSeq, - expressionsMap), + replaceWithExpressionTransformer0(expr.children.head, attributeSeq, expressionsMap), expr) case _: StringToMap => BackendsApiManager.getSparkPlanExecApiInstance.genStringToMapTransformer( substraitExprName, - expr.children.map( - replaceWithExpressionTransformerInternal(_, attributeSeq, expressionsMap)), + expr.children.map(replaceWithExpressionTransformer0(_, attributeSeq, expressionsMap)), expr) case CheckOverflow(b: BinaryArithmetic, decimalType, _) if !BackendsApiManager.getSettings.transformCheckOverflow && DecimalArithmeticUtil.isDecimalArithmetic(b) => DecimalArithmeticUtil.checkAllowDecimalArithmetic() val leftChild = - replaceWithExpressionTransformerInternal(b.left, attributeSeq, expressionsMap) + replaceWithExpressionTransformer0(b.left, attributeSeq, expressionsMap) val rightChild = - replaceWithExpressionTransformerInternal(b.right, attributeSeq, expressionsMap) + replaceWithExpressionTransformer0(b.right, attributeSeq, expressionsMap) DecimalArithmeticExpressionTransformer( getAndCheckSubstraitName(b, expressionsMap), leftChild, @@ -480,15 +464,14 @@ object ExpressionConverter extends SQLConfHelper with Logging { case c: CheckOverflow => CheckOverflowTransformer( substraitExprName, - replaceWithExpressionTransformerInternal(c.child, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(c.child, attributeSeq, expressionsMap), c) case b: BinaryArithmetic if DecimalArithmeticUtil.isDecimalArithmetic(b) => DecimalArithmeticUtil.checkAllowDecimalArithmetic() if (!BackendsApiManager.getSettings.transformCheckOverflow) { GenericExpressionTransformer( substraitExprName, - expr.children.map( - replaceWithExpressionTransformerInternal(_, attributeSeq, expressionsMap)), + expr.children.map(replaceWithExpressionTransformer0(_, attributeSeq, expressionsMap)), expr ) } else { @@ -499,14 +482,14 @@ object ExpressionConverter extends SQLConfHelper with Logging { case n: NaNvl => BackendsApiManager.getSparkPlanExecApiInstance.genNaNvlTransformer( substraitExprName, - replaceWithExpressionTransformerInternal(n.left, attributeSeq, expressionsMap), - replaceWithExpressionTransformerInternal(n.right, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(n.left, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(n.right, attributeSeq, expressionsMap), n ) case m: MakeTimestamp => BackendsApiManager.getSparkPlanExecApiInstance.genMakeTimestampTransformer( substraitExprName, - m.children.map(replaceWithExpressionTransformerInternal(_, attributeSeq, expressionsMap)), + m.children.map(replaceWithExpressionTransformer0(_, attributeSeq, expressionsMap)), m) case timestampAdd if timestampAdd.getClass.getSimpleName.equals("TimestampAdd") => // for spark3.3 @@ -518,111 +501,99 @@ object ExpressionConverter extends SQLConfHelper with Logging { TimestampAddTransformer( substraitExprName, extract.get.head, - replaceWithExpressionTransformerInternal(add.left, attributeSeq, expressionsMap), - replaceWithExpressionTransformerInternal(add.right, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(add.left, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(add.right, attributeSeq, expressionsMap), extract.get.last, add ) case e: Transformable => val childrenTransformers = - e.children.map(replaceWithExpressionTransformerInternal(_, attributeSeq, expressionsMap)) + e.children.map(replaceWithExpressionTransformer0(_, attributeSeq, expressionsMap)) e.getTransformer(childrenTransformers) case u: Uuid => BackendsApiManager.getSparkPlanExecApiInstance.genUuidTransformer(substraitExprName, u) case f: ArrayFilter => BackendsApiManager.getSparkPlanExecApiInstance.genArrayFilterTransformer( substraitExprName, - replaceWithExpressionTransformerInternal(f.argument, attributeSeq, expressionsMap), - replaceWithExpressionTransformerInternal(f.function, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(f.argument, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(f.function, attributeSeq, expressionsMap), f ) case arrayTransform: ArrayTransform => BackendsApiManager.getSparkPlanExecApiInstance.genArrayTransformTransformer( substraitExprName, - replaceWithExpressionTransformerInternal( - arrayTransform.argument, - attributeSeq, - expressionsMap), - replaceWithExpressionTransformerInternal( - arrayTransform.function, - attributeSeq, - expressionsMap), + replaceWithExpressionTransformer0(arrayTransform.argument, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(arrayTransform.function, attributeSeq, expressionsMap), arrayTransform ) case arraySort: ArraySort => BackendsApiManager.getSparkPlanExecApiInstance.genArraySortTransformer( substraitExprName, - replaceWithExpressionTransformerInternal( - arraySort.argument, - attributeSeq, - expressionsMap), - replaceWithExpressionTransformerInternal( - arraySort.function, - attributeSeq, - expressionsMap), + replaceWithExpressionTransformer0(arraySort.argument, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(arraySort.function, attributeSeq, expressionsMap), arraySort ) case tryEval @ TryEval(a: Add) => BackendsApiManager.getSparkPlanExecApiInstance.genTryArithmeticTransformer( substraitExprName, - replaceWithExpressionTransformerInternal(a.left, attributeSeq, expressionsMap), - replaceWithExpressionTransformerInternal(a.right, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(a.left, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(a.right, attributeSeq, expressionsMap), tryEval, ExpressionNames.CHECKED_ADD ) case tryEval @ TryEval(a: Subtract) => BackendsApiManager.getSparkPlanExecApiInstance.genTryArithmeticTransformer( substraitExprName, - replaceWithExpressionTransformerInternal(a.left, attributeSeq, expressionsMap), - replaceWithExpressionTransformerInternal(a.right, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(a.left, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(a.right, attributeSeq, expressionsMap), tryEval, ExpressionNames.CHECKED_SUBTRACT ) case tryEval @ TryEval(a: Divide) => BackendsApiManager.getSparkPlanExecApiInstance.genTryArithmeticTransformer( substraitExprName, - replaceWithExpressionTransformerInternal(a.left, attributeSeq, expressionsMap), - replaceWithExpressionTransformerInternal(a.right, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(a.left, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(a.right, attributeSeq, expressionsMap), tryEval, ExpressionNames.CHECKED_DIVIDE ) case tryEval @ TryEval(a: Multiply) => BackendsApiManager.getSparkPlanExecApiInstance.genTryArithmeticTransformer( substraitExprName, - replaceWithExpressionTransformerInternal(a.left, attributeSeq, expressionsMap), - replaceWithExpressionTransformerInternal(a.right, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(a.left, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(a.right, attributeSeq, expressionsMap), tryEval, ExpressionNames.CHECKED_MULTIPLY ) case a: Add => BackendsApiManager.getSparkPlanExecApiInstance.genArithmeticTransformer( substraitExprName, - replaceWithExpressionTransformerInternal(a.left, attributeSeq, expressionsMap), - replaceWithExpressionTransformerInternal(a.right, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(a.left, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(a.right, attributeSeq, expressionsMap), a, ExpressionNames.CHECKED_ADD ) case a: Subtract => BackendsApiManager.getSparkPlanExecApiInstance.genArithmeticTransformer( substraitExprName, - replaceWithExpressionTransformerInternal(a.left, attributeSeq, expressionsMap), - replaceWithExpressionTransformerInternal(a.right, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(a.left, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(a.right, attributeSeq, expressionsMap), a, ExpressionNames.CHECKED_SUBTRACT ) case a: Multiply => BackendsApiManager.getSparkPlanExecApiInstance.genArithmeticTransformer( substraitExprName, - replaceWithExpressionTransformerInternal(a.left, attributeSeq, expressionsMap), - replaceWithExpressionTransformerInternal(a.right, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(a.left, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(a.right, attributeSeq, expressionsMap), a, ExpressionNames.CHECKED_MULTIPLY ) case a: Divide => BackendsApiManager.getSparkPlanExecApiInstance.genArithmeticTransformer( substraitExprName, - replaceWithExpressionTransformerInternal(a.left, attributeSeq, expressionsMap), - replaceWithExpressionTransformerInternal(a.right, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(a.left, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(a.right, attributeSeq, expressionsMap), a, ExpressionNames.CHECKED_DIVIDE ) @@ -630,34 +601,34 @@ object ExpressionConverter extends SQLConfHelper with Logging { // This is a placeholder to handle try_eval(other expressions). BackendsApiManager.getSparkPlanExecApiInstance.genTryEvalTransformer( substraitExprName, - replaceWithExpressionTransformerInternal(tryEval.child, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(tryEval.child, attributeSeq, expressionsMap), tryEval ) case a: ArrayForAll => BackendsApiManager.getSparkPlanExecApiInstance.genArrayForAllTransformer( substraitExprName, - replaceWithExpressionTransformerInternal(a.argument, attributeSeq, expressionsMap), - replaceWithExpressionTransformerInternal(a.function, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(a.argument, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(a.function, attributeSeq, expressionsMap), a ) case a: ArrayExists => BackendsApiManager.getSparkPlanExecApiInstance.genArrayExistsTransformer( substraitExprName, - replaceWithExpressionTransformerInternal(a.argument, attributeSeq, expressionsMap), - replaceWithExpressionTransformerInternal(a.function, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(a.argument, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(a.function, attributeSeq, expressionsMap), a ) case s: Shuffle => GenericExpressionTransformer( substraitExprName, Seq( - replaceWithExpressionTransformerInternal(s.child, attributeSeq, expressionsMap), + replaceWithExpressionTransformer0(s.child, attributeSeq, expressionsMap), LiteralTransformer(Literal(s.randomSeed.get))), s) case c: PreciseTimestampConversion => BackendsApiManager.getSparkPlanExecApiInstance.genPreciseTimestampConversionTransformer( substraitExprName, - Seq(replaceWithExpressionTransformerInternal(c.child, attributeSeq, expressionsMap)), + Seq(replaceWithExpressionTransformer0(c.child, attributeSeq, expressionsMap)), c ) case t: TransformKeys => @@ -672,7 +643,7 @@ object ExpressionConverter extends SQLConfHelper with Logging { } GenericExpressionTransformer( substraitExprName, - t.children.map(replaceWithExpressionTransformerInternal(_, attributeSeq, expressionsMap)), + t.children.map(replaceWithExpressionTransformer0(_, attributeSeq, expressionsMap)), t ) case e: EulerNumber => @@ -698,8 +669,7 @@ object ExpressionConverter extends SQLConfHelper with Logging { case expr => GenericExpressionTransformer( substraitExprName, - expr.children.map( - replaceWithExpressionTransformerInternal(_, attributeSeq, expressionsMap)), + expr.children.map(replaceWithExpressionTransformer0(_, attributeSeq, expressionsMap)), expr ) } diff --git a/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala b/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala index c4799366dc96..e064f2afc9d7 100644 --- a/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala +++ b/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.execution.datasources.text.{GlutenTextV1Suite, Glute import org.apache.spark.sql.execution.datasources.v2.GlutenFileTableSuite import org.apache.spark.sql.execution.exchange.GlutenEnsureRequirementsSuite import org.apache.spark.sql.execution.joins.{GlutenBroadcastJoinSuite, GlutenExistenceJoinSuite, GlutenInnerJoinSuite, GlutenOuterJoinSuite} -import org.apache.spark.sql.extension.{GlutenCollapseProjectExecTransformerSuite, GlutenCustomerExpressionTransformerSuite, GlutenSessionExtensionSuite} +import org.apache.spark.sql.extension.{GlutenCollapseProjectExecTransformerSuite, GlutenSessionExtensionSuite} import org.apache.spark.sql.hive.execution.GlutenHiveSQLQuerySuite import org.apache.spark.sql.sources._ @@ -44,7 +44,6 @@ import org.apache.spark.sql.sources._ class VeloxTestSettings extends BackendTestSettings { enableSuite[GlutenSessionExtensionSuite] - enableSuite[GlutenCustomerExpressionTransformerSuite] enableSuite[GlutenDataFrameAggregateSuite] .exclude( From 8a954e17cbf2bc7511d43838232052008d052be6 Mon Sep 17 00:00:00 2001 From: Hongze Zhang Date: Mon, 26 Aug 2024 09:58:21 +0800 Subject: [PATCH 3/5] fixup fixup fixup --- .../sql/extension/CustomerExpressionTransformer.scala | 0 ...ClickhouseCustomerExpressionTransformerSuite.scala | 11 ++++++----- .../apache/gluten/expression/ExpressionMappings.scala | 11 ++++++++--- 3 files changed, 14 insertions(+), 8 deletions(-) rename {gluten-ut/spark32 => backends-clickhouse}/src/test/scala/org/apache/spark/sql/extension/CustomerExpressionTransformer.scala (100%) rename gluten-ut/spark32/src/test/scala/org/apache/spark/sql/extension/GlutenCustomerExpressionTransformerSuite.scala => backends-clickhouse/src/test/scala/org/apache/spark/sql/extension/GlutenClickhouseCustomerExpressionTransformerSuite.scala (91%) diff --git a/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/extension/CustomerExpressionTransformer.scala b/backends-clickhouse/src/test/scala/org/apache/spark/sql/extension/CustomerExpressionTransformer.scala similarity index 100% rename from gluten-ut/spark32/src/test/scala/org/apache/spark/sql/extension/CustomerExpressionTransformer.scala rename to backends-clickhouse/src/test/scala/org/apache/spark/sql/extension/CustomerExpressionTransformer.scala diff --git a/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/extension/GlutenCustomerExpressionTransformerSuite.scala b/backends-clickhouse/src/test/scala/org/apache/spark/sql/extension/GlutenClickhouseCustomerExpressionTransformerSuite.scala similarity index 91% rename from gluten-ut/spark32/src/test/scala/org/apache/spark/sql/extension/GlutenCustomerExpressionTransformerSuite.scala rename to backends-clickhouse/src/test/scala/org/apache/spark/sql/extension/GlutenClickhouseCustomerExpressionTransformerSuite.scala index 91344f8778ca..87f0953f3b87 100644 --- a/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/extension/GlutenCustomerExpressionTransformerSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/spark/sql/extension/GlutenClickhouseCustomerExpressionTransformerSuite.scala @@ -16,17 +16,17 @@ */ package org.apache.spark.sql.extension -import org.apache.gluten.execution.ProjectExecTransformer +import org.apache.gluten.execution.{GlutenClickHouseWholeStageTransformerSuite, ProjectExecTransformer} import org.apache.gluten.expression.ExpressionConverter import org.apache.spark.SparkConf -import org.apache.spark.sql.{GlutenSQLTestsTrait, Row} +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.analysis.FunctionRegistryBase import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.{IntervalUtils, TypeUtils} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{AbstractDataType, CalendarIntervalType, DayTimeIntervalType, TypeCollection, YearMonthIntervalType} +import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval case class CustomAdd( @@ -71,7 +71,8 @@ case class CustomAdd( ): CustomAdd = copy(left = newLeft, right = newRight) } -class GlutenCustomerExpressionTransformerSuite extends GlutenSQLTestsTrait { +class GlutenClickhouseCustomerExpressionTransformerSuite + extends GlutenClickHouseWholeStageTransformerSuite { override def sparkConf: SparkConf = { super.sparkConf @@ -92,7 +93,7 @@ class GlutenCustomerExpressionTransformerSuite extends GlutenSQLTestsTrait { ) } - testGluten("test custom expression transformer") { + test("test custom expression transformer") { spark .createDataFrame(Seq((1, 1.1), (2, 2.2))) .createOrReplaceTempView("custom_table") diff --git a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala index ea7127a40e6f..38f9de629a16 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala @@ -337,13 +337,18 @@ object ExpressionMappings { def expressionsMap: Map[Class[_], String] = { val blacklist = GlutenConfig.getConf.expressionBlacklist - val filtered = defaultExpressionsMap.filterNot(kv => blacklist.contains(kv._2)) + val filtered = (defaultExpressionsMap ++ toMap( + BackendsApiManager.getSparkPlanExecApiInstance.extraExpressionMappings)).filterNot( + kv => blacklist.contains(kv._2)) filtered } private lazy val defaultExpressionsMap: Map[Class[_], String] = { - (SCALAR_SIGS ++ AGGREGATE_SIGS ++ WINDOW_SIGS ++ - BackendsApiManager.getSparkPlanExecApiInstance.extraExpressionMappings) + toMap(SCALAR_SIGS ++ AGGREGATE_SIGS ++ WINDOW_SIGS) + } + + private def toMap(sigs: Seq[Sig]): Map[Class[_], String] = { + sigs .map(s => (s.expClass, s.name)) .toMap[Class[_], String] } From 264a7e9717504567a1ae9e0013bf364ea26b4244 Mon Sep 17 00:00:00 2001 From: Hongze Zhang Date: Mon, 26 Aug 2024 10:30:21 +0800 Subject: [PATCH 4/5] fixup --- .../gluten/utils/clickhouse/ClickHouseTestSettings.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala b/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala index 5c2833de4bc0..9b2e2ab95bc9 100644 --- a/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala +++ b/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.execution.datasources.text.{GlutenTextV1Suite, Glute import org.apache.spark.sql.execution.datasources.v2._ import org.apache.spark.sql.execution.exchange.GlutenEnsureRequirementsSuite import org.apache.spark.sql.execution.joins.{GlutenExistenceJoinSuite, GlutenInnerJoinSuite, GlutenOuterJoinSuite} -import org.apache.spark.sql.extension.{GlutenCustomerExpressionTransformerSuite, GlutenCustomerExtensionSuite, GlutenSessionExtensionSuite} +import org.apache.spark.sql.extension.{GlutenCustomerExtensionSuite, GlutenSessionExtensionSuite} import org.apache.spark.sql.hive.execution.GlutenHiveSQLQueryCHSuite import org.apache.spark.sql.sources._ import org.apache.spark.sql.statistics.SparkFunctionStatistics @@ -2133,7 +2133,6 @@ class ClickHouseTestSettings extends BackendTestSettings { .exclude("right outer join with unique keys using ShuffledHashJoin (whole-stage-codegen on)") .exclude("right outer join with unique keys using SortMergeJoin (whole-stage-codegen off)") .exclude("right outer join with unique keys using SortMergeJoin (whole-stage-codegen on)") - enableSuite[GlutenCustomerExpressionTransformerSuite] enableSuite[GlutenCustomerExtensionSuite] enableSuite[GlutenSessionExtensionSuite] enableSuite[GlutenBucketedReadWithoutHiveSupportSuite] From a5d5ea61452de1d24c1ecf9e55b5168a6848f71a Mon Sep 17 00:00:00 2001 From: Hongze Zhang Date: Mon, 26 Aug 2024 12:17:41 +0800 Subject: [PATCH 5/5] fixup --- ...seCustomerExpressionTransformerSuite.scala | 13 +++++-- .../sql/catalyst/expressions/EvalMode.scala | 36 +++++++++++++++++++ .../sql/catalyst/expressions/EvalMode.scala | 36 +++++++++++++++++++ 3 files changed, 83 insertions(+), 2 deletions(-) create mode 100644 shims/spark32/src/main/scala/org/apache/spark/sql/catalyst/expressions/EvalMode.scala create mode 100644 shims/spark33/src/main/scala/org/apache/spark/sql/catalyst/expressions/EvalMode.scala diff --git a/backends-clickhouse/src/test/scala/org/apache/spark/sql/extension/GlutenClickhouseCustomerExpressionTransformerSuite.scala b/backends-clickhouse/src/test/scala/org/apache/spark/sql/extension/GlutenClickhouseCustomerExpressionTransformerSuite.scala index 87f0953f3b87..cd8bf579fa66 100644 --- a/backends-clickhouse/src/test/scala/org/apache/spark/sql/extension/GlutenClickhouseCustomerExpressionTransformerSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/spark/sql/extension/GlutenClickhouseCustomerExpressionTransformerSuite.scala @@ -32,8 +32,9 @@ import org.apache.spark.unsafe.types.CalendarInterval case class CustomAdd( left: Expression, right: Expression, - failOnError: Boolean = SQLConf.get.ansiEnabled) - extends BinaryArithmetic { + override val failOnError: Boolean = SQLConf.get.ansiEnabled) + extends BinaryArithmetic + with CustomAdd.Compatibility { def this(left: Expression, right: Expression) = this(left, right, SQLConf.get.ansiEnabled) @@ -69,6 +70,14 @@ case class CustomAdd( newLeft: Expression, newRight: Expression ): CustomAdd = copy(left = newLeft, right = newRight) + + override protected val evalMode: EvalMode.Value = EvalMode.LEGACY +} + +object CustomAdd { + trait Compatibility { + protected val evalMode: EvalMode.Value + } } class GlutenClickhouseCustomerExpressionTransformerSuite diff --git a/shims/spark32/src/main/scala/org/apache/spark/sql/catalyst/expressions/EvalMode.scala b/shims/spark32/src/main/scala/org/apache/spark/sql/catalyst/expressions/EvalMode.scala new file mode 100644 index 000000000000..0a3c63ccd8b9 --- /dev/null +++ b/shims/spark32/src/main/scala/org/apache/spark/sql/catalyst/expressions/EvalMode.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.internal.SQLConf + +/** For compatibility with Spark version <= 3.3. The class was added in vanilla Spark since 3.4. */ +object EvalMode extends Enumeration { + val LEGACY, ANSI, TRY = Value + + def fromSQLConf(conf: SQLConf): Value = if (conf.ansiEnabled) { + ANSI + } else { + LEGACY + } + + def fromBoolean(ansiEnabled: Boolean): Value = if (ansiEnabled) { + ANSI + } else { + LEGACY + } +} diff --git a/shims/spark33/src/main/scala/org/apache/spark/sql/catalyst/expressions/EvalMode.scala b/shims/spark33/src/main/scala/org/apache/spark/sql/catalyst/expressions/EvalMode.scala new file mode 100644 index 000000000000..0a3c63ccd8b9 --- /dev/null +++ b/shims/spark33/src/main/scala/org/apache/spark/sql/catalyst/expressions/EvalMode.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.internal.SQLConf + +/** For compatibility with Spark version <= 3.3. The class was added in vanilla Spark since 3.4. */ +object EvalMode extends Enumeration { + val LEGACY, ANSI, TRY = Value + + def fromSQLConf(conf: SQLConf): Value = if (conf.ansiEnabled) { + ANSI + } else { + LEGACY + } + + def fromBoolean(ansiEnabled: Boolean): Value = if (ansiEnabled) { + ANSI + } else { + LEGACY + } +}