[VL] Allow udf type conversion (#6660)
marin-ma authored Aug 6, 2024
1 parent 0625a75 commit 43d0ff9
Showing 10 changed files with 216 additions and 49 deletions.
@@ -66,6 +66,7 @@ object VeloxBackendSettings extends BackendSettingsApi {
   val GLUTEN_VELOX_UDF_LIB_PATHS = getBackendConfigPrefix() + ".udfLibraryPaths"
   val GLUTEN_VELOX_DRIVER_UDF_LIB_PATHS = getBackendConfigPrefix() + ".driver.udfLibraryPaths"
   val GLUTEN_VELOX_INTERNAL_UDF_LIB_PATHS = getBackendConfigPrefix() + ".internal.udfLibraryPaths"
+  val GLUTEN_VELOX_UDF_ALLOW_TYPE_CONVERSION = getBackendConfigPrefix() + ".udfAllowTypeConversion"

   val MAXIMUM_BATCH_SIZE: Int = 32768
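The new key is assembled at runtime from getBackendConfigPrefix(). A minimal sketch of toggling it per session, assuming the prefix resolves to spark.gluten.sql.columnar.backend.velox (verify against your build):

    // Hypothetical session-level toggle; the key name is
    // getBackendConfigPrefix() + ".udfAllowTypeConversion".
    spark.conf.set("spark.gluten.sql.columnar.backend.velox.udfAllowTypeConversion", "true")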
@@ -27,11 +27,12 @@ import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow}
 import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, ExpressionInfo, Unevaluable}
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Cast, Expression, ExpressionInfo, Unevaluable}
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
 import org.apache.spark.sql.catalyst.types.DataTypeUtils
 import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{DataType, StructField, StructType}
 import org.apache.spark.util.Utils
@@ -74,18 +75,21 @@ trait UDFSignatureBase {
   val expressionType: ExpressionType
   val children: Seq[DataType]
   val variableArity: Boolean
+  val allowTypeConversion: Boolean
 }

 case class UDFSignature(
     expressionType: ExpressionType,
     children: Seq[DataType],
-    variableArity: Boolean)
+    variableArity: Boolean,
+    allowTypeConversion: Boolean)
   extends UDFSignatureBase

 case class UDAFSignature(
     expressionType: ExpressionType,
     children: Seq[DataType],
     variableArity: Boolean,
+    allowTypeConversion: Boolean,
     intermediateAttrs: Seq[AttributeReference])
   extends UDFSignatureBase

@@ -130,26 +134,30 @@ object UDFResolver extends Logging {
       name: String,
       returnType: Array[Byte],
       argTypes: Array[Byte],
-      variableArity: Boolean): Unit = {
+      variableArity: Boolean,
+      allowTypeConversion: Boolean): Unit = {
     registerUDF(
       name,
       ConverterUtils.parseFromBytes(returnType),
       ConverterUtils.parseFromBytes(argTypes),
-      variableArity)
+      variableArity,
+      allowTypeConversion)
   }

   private def registerUDF(
       name: String,
       returnType: ExpressionType,
       argTypes: ExpressionType,
-      variableArity: Boolean): Unit = {
+      variableArity: Boolean,
+      allowTypeConversion: Boolean): Unit = {
     assert(argTypes.dataType.isInstanceOf[StructType])
     val v =
       UDFMap.getOrElseUpdate(name, mutable.MutableList[UDFSignature]())
     v += UDFSignature(
       returnType,
       argTypes.dataType.asInstanceOf[StructType].fields.map(_.dataType),
-      variableArity)
+      variableArity,
+      allowTypeConversion)
     UDFNames += name
     logInfo(s"Registered UDF: $name($argTypes) -> $returnType")
   }
@@ -159,13 +167,15 @@ object UDFResolver extends Logging {
       returnType: Array[Byte],
       argTypes: Array[Byte],
       intermediateTypes: Array[Byte],
-      variableArity: Boolean): Unit = {
+      variableArity: Boolean,
+      enableTypeConversion: Boolean): Unit = {
     registerUDAF(
       name,
       ConverterUtils.parseFromBytes(returnType),
       ConverterUtils.parseFromBytes(argTypes),
       ConverterUtils.parseFromBytes(intermediateTypes),
-      variableArity
+      variableArity,
+      enableTypeConversion
     )
   }

@@ -174,7 +184,8 @@ object UDFResolver extends Logging {
       returnType: ExpressionType,
       argTypes: ExpressionType,
       intermediateTypes: ExpressionType,
-      variableArity: Boolean): Unit = {
+      variableArity: Boolean,
+      allowTypeConversion: Boolean): Unit = {
     assert(argTypes.dataType.isInstanceOf[StructType])

     val aggBufferAttributes: Seq[AttributeReference] =
@@ -194,6 +205,7 @@ object UDFResolver extends Logging {
       returnType,
       argTypes.dataType.asInstanceOf[StructType].fields.map(_.dataType),
       variableArity,
+      allowTypeConversion,
       aggBufferAttributes)
     UDAFNames += name
     logInfo(s"Registered UDAF: $name($argTypes) -> $returnType")
@@ -346,16 +358,27 @@ object UDFResolver extends Logging {
     }
   }

+  private def checkAllowTypeConversion: Boolean = {
+    SQLConf.get
+      .getConfString(VeloxBackendSettings.GLUTEN_VELOX_UDF_ALLOW_TYPE_CONVERSION, "false")
+      .toBoolean
+  }
+
   private def getUdfExpression(name: String)(children: Seq[Expression]) = {
     def errorMessage: String =
       s"UDF $name -> ${children.map(_.dataType.simpleString).mkString(", ")} is not registered."

+    val allowTypeConversion = checkAllowTypeConversion
     val signatures =
       UDFMap.getOrElse(name, throw new UnsupportedOperationException(errorMessage));

-    signatures.find(sig => tryBind(sig, children.map(_.dataType))) match {
+    signatures.find(sig => tryBind(sig, children.map(_.dataType), allowTypeConversion)) match {
       case Some(sig) =>
-        UDFExpression(name, sig.expressionType.dataType, sig.expressionType.nullable, children)
+        UDFExpression(
+          name,
+          sig.expressionType.dataType,
+          sig.expressionType.nullable,
+          if (!allowTypeConversion && !sig.allowTypeConversion) children
+          else applyCast(children, sig))
       case None =>
         throw new UnsupportedOperationException(errorMessage)
     }
@@ -365,62 +388,116 @@ object UDFResolver extends Logging {
     def errorMessage: String =
       s"UDAF $name -> ${children.map(_.dataType.simpleString).mkString(", ")} is not registered."

+    val allowTypeConversion = checkAllowTypeConversion
     val signatures =
       UDAFMap.getOrElse(
         name,
         throw new UnsupportedOperationException(errorMessage)
       )

-    signatures.find(sig => tryBind(sig, children.map(_.dataType))) match {
+    signatures.find(sig => tryBind(sig, children.map(_.dataType), allowTypeConversion)) match {
       case Some(sig) =>
         UserDefinedAggregateFunction(
           name,
           sig.expressionType.dataType,
           sig.expressionType.nullable,
-          children,
-          sig.intermediateAttrs)
+          if (!allowTypeConversion && !sig.allowTypeConversion) children
+          else applyCast(children, sig),
+          sig.intermediateAttrs
+        )
       case None =>
         throw new UnsupportedOperationException(errorMessage)
     }
   }

+  private def tryBind(
+      sig: UDFSignatureBase,
+      requiredDataTypes: Seq[DataType],
+      allowTypeConversion: Boolean): Boolean = {
+    if (
+      !tryBindStrict(sig, requiredDataTypes) && (allowTypeConversion || sig.allowTypeConversion)
+    ) {
+      tryBindWithTypeConversion(sig, requiredDataTypes)
+    } else {
+      true
+    }
+  }
+
   // Returns true if required data types match the function signature.
   // If the function signature is variable arity, the number of the last argument can be zero
   // or more.
-  private def tryBind(sig: UDFSignatureBase, requiredDataTypes: Seq[DataType]): Boolean = {
+  private def tryBindWithTypeConversion(
+      sig: UDFSignatureBase,
+      requiredDataTypes: Seq[DataType]): Boolean = {
+    tryBind0(sig, requiredDataTypes, Cast.canCast)
+  }
+
+  private def tryBindStrict(sig: UDFSignatureBase, requiredDataTypes: Seq[DataType]): Boolean = {
+    tryBind0(sig, requiredDataTypes, DataTypeUtils.sameType)
+  }
+
+  private def tryBind0(
+      sig: UDFSignatureBase,
+      requiredDataTypes: Seq[DataType],
+      checkType: (DataType, DataType) => Boolean): Boolean = {
     if (!sig.variableArity) {
       sig.children.size == requiredDataTypes.size &&
-      sig.children
-        .zip(requiredDataTypes)
-        .forall { case (candidate, required) => DataTypeUtils.sameType(candidate, required) }
+      requiredDataTypes
+        .zip(sig.children)
+        .forall { case (required, candidate) => checkType(required, candidate) }
     } else {
       // If variableArity is true, there must be at least one argument in the signature.
       if (requiredDataTypes.size < sig.children.size - 1) {
         false
       } else if (requiredDataTypes.size == sig.children.size - 1) {
-        sig.children
-          .dropRight(1)
-          .zip(requiredDataTypes)
-          .forall { case (candidate, required) => DataTypeUtils.sameType(candidate, required) }
+        requiredDataTypes
+          .zip(sig.children.dropRight(1))
+          .forall { case (required, candidate) => checkType(required, candidate) }
       } else {
         val varArgStartIndex = sig.children.size - 1
         // First check all var args has the same type with the last argument of the signature.
         if (
           !requiredDataTypes
             .drop(varArgStartIndex)
-            .forall(argType => DataTypeUtils.sameType(sig.children.last, argType))
+            .forall(argType => checkType(argType, sig.children.last))
         ) {
           false
         } else if (varArgStartIndex == 0) {
           // No fixed args.
           true
         } else {
           // Whether fixed args matches.
-          sig.children
-            .dropRight(1)
-            .zip(requiredDataTypes.dropRight(1 + requiredDataTypes.size - sig.children.size))
-            .forall { case (candidate, required) => DataTypeUtils.sameType(candidate, required) }
+          requiredDataTypes
+            .dropRight(1 + requiredDataTypes.size - sig.children.size)
+            .zip(sig.children.dropRight(1))
+            .forall { case (required, candidate) => checkType(required, candidate) }
         }
       }
     }
   }

+  private def applyCast(children: Seq[Expression], sig: UDFSignatureBase): Seq[Expression] = {
+    def maybeCast(expr: Expression, toType: DataType): Expression = {
+      if (!expr.dataType.sameType(toType)) {
+        Cast(expr, toType)
+      } else {
+        expr
+      }
+    }
+
+    if (!sig.variableArity) {
+      children.zip(sig.children).map { case (expr, toType) => maybeCast(expr, toType) }
+    } else {
+      val fixedArgs = Math.min(children.size, sig.children.size)
+      val newChildren = children.take(fixedArgs).zip(sig.children.take(fixedArgs)).map {
+        case (expr, toType) => maybeCast(expr, toType)
+      }
+      if (children.size > sig.children.size) {
+        val varArgType = sig.children.last
+        newChildren ++ children.takeRight(children.size - sig.children.size).map {
+          expr => maybeCast(expr, varArgType)
+        }
+      } else {
+        newChildren
+      }
+    }
+  }
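To make the binding rules concrete, here is a small illustrative sketch (not part of the commit) of what the two predicates return and what applyCast produces for a hypothetical UDF declared over a single BIGINT parameter. Cast.canCast and DataTypeUtils.sameType are the same Spark APIs the resolver passes to tryBind0; visibility may require the snippet to live under Spark's namespace, as UDFResolver does:

    import org.apache.spark.sql.catalyst.expressions.{Cast, Literal}
    import org.apache.spark.sql.catalyst.types.DataTypeUtils
    import org.apache.spark.sql.types.{LongType, StringType}

    // Strict binding compares types exactly, so a STRING argument fails:
    DataTypeUtils.sameType(StringType, LongType) // false -> tryBindStrict rejects
    // Conversion-aware binding accepts any Spark-castable pair:
    Cast.canCast(StringType, LongType) // true -> tryBindWithTypeConversion binds

    // applyCast then wraps a mismatched child, so the native UDF sees
    // its declared parameter type:
    val adjusted = Cast(Literal("100"), LongType) // what maybeCast produces

Note that either flag is enough to reach the cast path: the session-wide setting read by checkAllowTypeConversion, or the per-signature allowTypeConversion reported by the native library.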
@@ -16,6 +16,7 @@
  */
 package org.apache.gluten.expression

+import org.apache.gluten.backendsapi.velox.VeloxBackendSettings
 import org.apache.gluten.tags.{SkipTestTags, UDFTest}

 import org.apache.spark.SparkConf
@@ -88,6 +89,23 @@ abstract class VeloxUdfSuite extends GlutenQueryTest with SQLHelper {
         .sameElements(Array(Row(105L, 6, 6L, 5, 6, 11, 6L, 11L, Date.valueOf("2024-03-30")))))
   }

+  test("test udf allow type conversion") {
+    withSQLConf(VeloxBackendSettings.GLUTEN_VELOX_UDF_ALLOW_TYPE_CONVERSION -> "true") {
+      val df = spark.sql("""select myudf1("100"), myudf1(1), mydate('2024-03-25', 5)""")
+      assert(
+        df.collect()
+          .sameElements(Array(Row(105L, 6L, Date.valueOf("2024-03-30")))))
+    }
+
+    withSQLConf(VeloxBackendSettings.GLUTEN_VELOX_UDF_ALLOW_TYPE_CONVERSION -> "false") {
+      assert(
+        spark
+          .sql("select mydate2('2024-03-25', 5)")
+          .collect()
+          .sameElements(Array(Row(Date.valueOf("2024-03-30")))))
+    }
+  }
+
   test("test udaf") {
     val df = spark.sql("""select
                          | myavg(1),
@@ -101,6 +119,15 @@ abstract class VeloxUdfSuite extends GlutenQueryTest with SQLHelper {
       df.collect()
         .sameElements(Array(Row(1.0, 1.0, 1.0, 1.0, 1L))))
   }
+
+  test("test udaf allow type conversion") {
+    withSQLConf(VeloxBackendSettings.GLUTEN_VELOX_UDF_ALLOW_TYPE_CONVERSION -> "true") {
+      val df = spark.sql("""select myavg("1"), myavg("1.0"), mycount_if("true")""")
+      assert(
+        df.collect()
+          .sameElements(Array(Row(1.0, 1.0, 1L))))
+    }
+  }
 }

 @UDFTest
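A note on the second case in "test udf allow type conversion": the session flag is off, yet mydate2 still accepts its arguments. This matches the resolver change above, where the cast path is also taken when the signature's own allowTypeConversion flag, reported by the native library and presumably set for mydate2, permits conversion.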
22 changes: 18 additions & 4 deletions cpp/velox/jni/JniUdf.cc
@@ -41,8 +41,8 @@ void gluten::initVeloxJniUDF(JNIEnv* env) {
   udfResolverClass = createGlobalClassReferenceOrError(env, kUdfResolverClassPath.c_str());

   // methods
-  registerUDFMethod = getMethodIdOrError(env, udfResolverClass, "registerUDF", "(Ljava/lang/String;[B[BZ)V");
-  registerUDAFMethod = getMethodIdOrError(env, udfResolverClass, "registerUDAF", "(Ljava/lang/String;[B[B[BZ)V");
+  registerUDFMethod = getMethodIdOrError(env, udfResolverClass, "registerUDF", "(Ljava/lang/String;[B[BZZ)V");
+  registerUDAFMethod = getMethodIdOrError(env, udfResolverClass, "registerUDAF", "(Ljava/lang/String;[B[B[BZZ)V");
 }

 void gluten::finalizeVeloxJniUDF(JNIEnv* env) {
@@ -71,9 +71,23 @@ void gluten::jniGetFunctionSignatures(JNIEnv* env) {
         signature->intermediateType.length(),
         reinterpret_cast<const jbyte*>(signature->intermediateType.c_str()));
     env->CallVoidMethod(
-        instance, registerUDAFMethod, name, returnType, argTypes, intermediateType, signature->variableArity);
+        instance,
+        registerUDAFMethod,
+        name,
+        returnType,
+        argTypes,
+        intermediateType,
+        signature->variableArity,
+        signature->allowTypeConversion);
   } else {
-    env->CallVoidMethod(instance, registerUDFMethod, name, returnType, argTypes, signature->variableArity);
+    env->CallVoidMethod(
+        instance,
+        registerUDFMethod,
+        name,
+        returnType,
+        argTypes,
+        signature->variableArity,
+        signature->allowTypeConversion);
   }
   checkException(env);
 }
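As a reading aid (not part of the commit): the extra Z in each JNI descriptor is the new boolean parameter. For the UDF case, (Ljava/lang/String;[B[BZZ)V maps token by token onto the Scala-side registration method shown earlier:

    // Ljava/lang/String; -> name: String
    // [B                 -> returnType: Array[Byte]
    // [B                 -> argTypes: Array[Byte]
    // Z                  -> variableArity: Boolean
    // Z                  -> allowTypeConversion: Boolean (new)
    // V                  -> Unit (void return)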
1 change: 1 addition & 0 deletions cpp/velox/udf/Udaf.h
@@ -28,6 +28,7 @@ struct UdafEntry {

   const char* intermediateType{nullptr};
   bool variableArity{false};
+  bool allowTypeConversion{false};
 };

 #define GLUTEN_GET_NUM_UDAF getNumUdaf
1 change: 1 addition & 0 deletions cpp/velox/udf/Udf.h
@@ -27,6 +27,7 @@ struct UdfEntry {
   const char** argTypes;

   bool variableArity{false};
+  bool allowTypeConversion{false};
 };

 #define GLUTEN_GET_NUM_UDF getNumUdf