[GLUTEN-6067][CH][MINOR][UT] Pass backends-clickhouse ut in Spark 3.5 #6623

Merged 2 commits on Jul 31, 2024
@@ -66,10 +66,10 @@ class GlutenClickHouseDecimalSuite
private val decimalTable: String = "decimal_table"
private val decimalTPCHTables: Seq[(DecimalType, Seq[Int])] = Seq.apply(
(DecimalType.apply(9, 4), Seq()),
// 1: ch decimal avg is float
(DecimalType.apply(18, 8), Seq()),
// 1: ch decimal avg is float, 3/10: all value is null and compare with limit
(DecimalType.apply(38, 19), Seq(3, 10))
// 3/10: all value is null and compare with limit
// 1 Spark 3.5
(DecimalType.apply(38, 19), if (isSparkVersionLE("3.3")) Seq(3, 10) else Seq(1, 3, 10))
)

private def createDecimalTables(dataType: DecimalType): Unit = {
@@ -343,27 +343,22 @@ class GlutenClickHouseDecimalSuite
decimalTPCHTables.foreach {
dt =>
{
val fallBack = (sql_num == 16 || sql_num == 21)
val compareResult = !dt._2.contains(sql_num)
val native = if (fallBack) "fallback" else "native"
val compare = if (compareResult) "compare" else "noCompare"
val PrecisionLoss = s"allowPrecisionLoss=$allowPrecisionLoss"
val decimalType = dt._1
test(s"""TPCH Decimal(${decimalType.precision},${decimalType.scale})
| Q$sql_num[allowPrecisionLoss=$allowPrecisionLoss]""".stripMargin) {
var noFallBack = true
var compareResult = true
if (sql_num == 16 || sql_num == 21) {
noFallBack = false
}

if (dt._2.contains(sql_num)) {
compareResult = false
}

| Q$sql_num[$PrecisionLoss,$native,$compare]""".stripMargin) {
spark.sql(s"use decimal_${decimalType.precision}_${decimalType.scale}")
withSQLConf(
(SQLConf.DECIMAL_OPERATIONS_ALLOW_PREC_LOSS.key, allowPrecisionLoss)) {
runTPCHQuery(
sql_num,
tpchQueries,
compareResult = compareResult,
noFallBack = noFallBack) { _ => {} }
noFallBack = !fallBack) { _ => {} }
}
spark.sql(s"use default")
}
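Most of the changes in this PR are gated on the Spark version via `isSparkVersionLE`/`isSparkVersionGE`. A minimal sketch of what such guards typically look like, assuming they compare the (major, minor) pair from `SPARK_VERSION_SHORT`; the real helpers live in the shared test base and may differ:

```scala
import org.apache.spark.SPARK_VERSION_SHORT

// Illustrative sketch only: compare (major, minor) of the running Spark
// against a threshold such as "3.3" or "3.5".
object SparkVersionGuardSketch {
  private def majorMinor(v: String): (Int, Int) = {
    val parts = v.split("\\.").take(2).map(_.toInt)
    (parts(0), parts(1))
  }

  def isSparkVersionLE(v: String): Boolean = {
    val (curMajor, curMinor) = majorMinor(SPARK_VERSION_SHORT)
    val (major, minor) = majorMinor(v)
    curMajor < major || (curMajor == major && curMinor <= minor)
  }

  def isSparkVersionGE(v: String): Boolean = {
    val (curMajor, curMinor) = majorMinor(SPARK_VERSION_SHORT)
    val (major, minor) = majorMinor(v)
    curMajor > major || (curMajor == major && curMinor >= minor)
  }
}
```

For example, `isSparkVersionLE("3.3")` is true on Spark 3.2 and 3.3, so the decimal suite above keeps the old expectations there and adds query 1 to the no-compare list only on newer versions.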
@@ -1051,8 +1051,12 @@ class GlutenClickHouseHiveTableSuite
spark.sql(
s"CREATE FUNCTION my_add as " +
s"'org.apache.hadoop.hive.contrib.udf.example.UDFExampleAdd2' USING JAR '$jarUrl'")
runQueryAndCompare("select MY_ADD(id, id+1) from range(10)")(
checkGlutenOperatorMatch[ProjectExecTransformer])
if (isSparkVersionLE("3.3")) {
runQueryAndCompare("select MY_ADD(id, id+1) from range(10)")(
checkGlutenOperatorMatch[ProjectExecTransformer])
} else {
runQueryAndCompare("select MY_ADD(id, id+1) from range(10)", noFallBack = false)(_ => {})
}
}

test("GLUTEN-4333: fix CSE in aggregate operator") {
@@ -603,7 +603,7 @@ class GlutenClickHouseNativeWriteTableSuite
("timestamp_field", "timestamp")
)
def excludeTimeFieldForORC(format: String): Seq[String] = {
if (format.equals("orc") && isSparkVersionGE("3.4")) {
if (format.equals("orc") && isSparkVersionGE("3.5")) {
// FIXME:https://github.com/apache/incubator-gluten/pull/6507
fields.keys.filterNot(_.equals("timestamp_field")).toSeq
} else {
@@ -913,7 +913,7 @@ class GlutenClickHouseNativeWriteTableSuite
(table_name, create_sql, insert_sql)
},
(table_name, _) =>
if (isSparkVersionGE("3.4")) {
if (isSparkVersionGE("3.5")) {
compareResultsAgainstVanillaSpark(
s"select * from $table_name",
compareResult = true,
@@ -18,7 +18,7 @@ package org.apache.gluten.execution

import org.apache.gluten.GlutenConfig
import org.apache.gluten.benchmarks.GenTPCDSTableScripts
import org.apache.gluten.utils.UTSystemParameters
import org.apache.gluten.utils.{Arm, UTSystemParameters}

import org.apache.spark.SparkConf
import org.apache.spark.internal.Logging
@@ -46,8 +46,8 @@ abstract class GlutenClickHouseTPCDSAbstractSuite
rootPath + "../../../../gluten-core/src/test/resources/tpcds-queries/tpcds.queries.original"
protected val queriesResults: String = rootPath + "tpcds-decimal-queries-output"

/** Return values: (sql num, is fall back, skip fall back assert) */
def tpcdsAllQueries(isAqe: Boolean): Seq[(String, Boolean, Boolean)] =
/** Return values: (sql num, is fall back) */
def tpcdsAllQueries(isAqe: Boolean): Seq[(String, Boolean)] =
Range
.inclusive(1, 99)
.flatMap(
@@ -57,37 +57,37 @@ abstract class GlutenClickHouseTPCDSAbstractSuite
} else {
Seq("q" + "%d".format(queryNum))
}
val noFallBack = queryNum match {
case i if !isAqe && (i == 10 || i == 16 || i == 35 || i == 94) =>
// q10 smj + existence join
// q16 smj + left semi + not condition
// q35 smj + existence join
// Q94 BroadcastHashJoin, LeftSemi, NOT condition
(false, false)
case i if isAqe && (i == 16 || i == 94) =>
(false, false)
case other => (true, false)
}
sqlNums.map((_, noFallBack._1, noFallBack._2))
val native = !fallbackSets(isAqe).contains(queryNum)
sqlNums.map((_, native))
})

// FIXME "q17", stddev_samp inconsistent results, CH return NaN, Spark return null
protected def fallbackSets(isAqe: Boolean): Set[Int] = {
val more = if (isSparkVersionGE("3.5")) Set(44, 67, 70) else Set.empty[Int]

// q16 smj + left semi + not condition
// Q94 BroadcastHashJoin, LeftSemi, NOT condition
if (isAqe) {
Set(16, 94) | more
} else {
// q10, q35 smj + existence join
Set(10, 16, 35, 94) | more
}
}
protected def excludedTpcdsQueries: Set[String] = Set(
"q61", // inconsistent results
"q66", // inconsistent results
"q67" // inconsistent results
"q66" // inconsistent results
)

def executeTPCDSTest(isAqe: Boolean): Unit = {
tpcdsAllQueries(isAqe).foreach(
s =>
if (excludedTpcdsQueries.contains(s._1)) {
ignore(s"TPCDS ${s._1.toUpperCase()}") {
runTPCDSQuery(s._1, noFallBack = s._2, skipFallBackAssert = s._3) { df => }
runTPCDSQuery(s._1, noFallBack = s._2) { df => }
}
} else {
test(s"TPCDS ${s._1.toUpperCase()}") {
runTPCDSQuery(s._1, noFallBack = s._2, skipFallBackAssert = s._3) { df => }
val tag = if (s._2) "Native" else "Fallback"
test(s"TPCDS[$tag] ${s._1.toUpperCase()}") {
runTPCDSQuery(s._1, noFallBack = s._2) { df => }
}
})
}
@@ -152,7 +152,7 @@ abstract class GlutenClickHouseTPCDSAbstractSuite
}

override protected def afterAll(): Unit = {
ClickhouseSnapshot.clearAllFileStatusCache
ClickhouseSnapshot.clearAllFileStatusCache()
DeltaLog.clearCache()

try {
@@ -183,11 +183,10 @@ abstract class GlutenClickHouseTPCDSAbstractSuite
tpcdsQueries: String = tpcdsQueries,
queriesResults: String = queriesResults,
compareResult: Boolean = true,
noFallBack: Boolean = true,
skipFallBackAssert: Boolean = false)(customCheck: DataFrame => Unit): Unit = {
noFallBack: Boolean = true)(customCheck: DataFrame => Unit): Unit = {

val sqlFile = tpcdsQueries + "/" + queryNum + ".sql"
val sql = Source.fromFile(new File(sqlFile), "UTF-8").mkString
val sql = Arm.withResource(Source.fromFile(new File(sqlFile), "UTF-8"))(_.mkString)
val df = spark.sql(sql)

if (compareResult) {
@@ -212,13 +211,13 @@ abstract class GlutenClickHouseTPCDSAbstractSuite
// using WARN to guarantee printed
log.warn(s"query: $queryNum, finish comparing with saved result")
} else {
val start = System.currentTimeMillis();
val start = System.currentTimeMillis()
val ret = df.collect()
// using WARN to guarantee printed
log.warn(s"query: $queryNum skipped comparing, time cost to collect: ${System
.currentTimeMillis() - start} ms, ret size: ${ret.length}")
}
WholeStageTransformerSuite.checkFallBack(df, noFallBack, skipFallBackAssert)
WholeStageTransformerSuite.checkFallBack(df, noFallBack)
customCheck(df)
}
}
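The other notable change in this suite is reading the query file through `Arm.withResource` instead of a bare `Source.fromFile(...).mkString`, so the handle is closed even if reading throws. A minimal loan-pattern sketch under that assumption; the actual `org.apache.gluten.utils.Arm` may differ in shape:

```scala
import java.io.File
import scala.io.Source

// Illustrative loan-pattern helper: run the body, then always close the resource.
object ArmSketch {
  def withResource[T <: AutoCloseable, R](resource: T)(body: T => R): R = {
    try body(resource)
    finally resource.close()
  }
}

// Usage mirroring the change in runTPCDSQuery: the Source is closed once
// mkString has materialized the query text.
def readQuery(sqlFile: String): String =
  ArmSketch.withResource(Source.fromFile(new File(sqlFile), "UTF-8"))(_.mkString)
```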
@@ -234,10 +234,10 @@ class GlutenClickHouseTPCHBucketSuite
val plans = collect(df.queryExecution.executedPlan) {
case scanExec: BasicScanExecTransformer => scanExec
}
assert(!(plans(0).asInstanceOf[FileSourceScanExecTransformer].bucketedScan))
assert(plans(0).metrics("numFiles").value === 2)
assert(plans(0).metrics("pruningTime").value === -1)
assert(plans(0).metrics("numOutputRows").value === 591673)
assert(!plans.head.asInstanceOf[FileSourceScanExecTransformer].bucketedScan)
assert(plans.head.metrics("numFiles").value === 2)
assert(plans.head.metrics("pruningTime").value === pruningTimeValueSpark)
assert(plans.head.metrics("numOutputRows").value === 591673)
})
}

@@ -291,7 +291,7 @@ class GlutenClickHouseTPCHBucketSuite
}

if (sparkVersion.equals("3.2")) {
assert(!(plans(11).asInstanceOf[FileSourceScanExecTransformer].bucketedScan))
assert(!plans(11).asInstanceOf[FileSourceScanExecTransformer].bucketedScan)
} else {
assert(plans(11).asInstanceOf[FileSourceScanExecTransformer].bucketedScan)
}
@@ -327,14 +327,14 @@ class GlutenClickHouseTPCHBucketSuite
.isInstanceOf[InputIteratorTransformer])

if (sparkVersion.equals("3.2")) {
assert(!(plans(2).asInstanceOf[FileSourceScanExecTransformer].bucketedScan))
assert(!plans(2).asInstanceOf[FileSourceScanExecTransformer].bucketedScan)
} else {
assert(plans(2).asInstanceOf[FileSourceScanExecTransformer].bucketedScan)
}
assert(plans(2).metrics("numFiles").value === 2)
assert(plans(2).metrics("numOutputRows").value === 3111)

assert(!(plans(3).asInstanceOf[FileSourceScanExecTransformer].bucketedScan))
assert(!plans(3).asInstanceOf[FileSourceScanExecTransformer].bucketedScan)
assert(plans(3).metrics("numFiles").value === 2)
assert(plans(3).metrics("numOutputRows").value === 72678)
})
@@ -366,12 +366,12 @@ class GlutenClickHouseTPCHBucketSuite
}
// bucket join
assert(
plans(0)
plans.head
.asInstanceOf[HashJoinLikeExecTransformer]
.left
.isInstanceOf[ProjectExecTransformer])
assert(
plans(0)
plans.head
.asInstanceOf[HashJoinLikeExecTransformer]
.right
.isInstanceOf[ProjectExecTransformer])
@@ -409,10 +409,10 @@ class GlutenClickHouseTPCHBucketSuite
val plans = collect(df.queryExecution.executedPlan) {
case scanExec: BasicScanExecTransformer => scanExec
}
assert(!(plans(0).asInstanceOf[FileSourceScanExecTransformer].bucketedScan))
assert(plans(0).metrics("numFiles").value === 2)
assert(plans(0).metrics("pruningTime").value === -1)
assert(plans(0).metrics("numOutputRows").value === 11618)
assert(!plans.head.asInstanceOf[FileSourceScanExecTransformer].bucketedScan)
assert(plans.head.metrics("numFiles").value === 2)
assert(plans.head.metrics("pruningTime").value === pruningTimeValueSpark)
assert(plans.head.metrics("numOutputRows").value === 11618)
})
}

@@ -425,12 +425,12 @@ class GlutenClickHouseTPCHBucketSuite
}
// bucket join
assert(
plans(0)
plans.head
.asInstanceOf[HashJoinLikeExecTransformer]
.left
.isInstanceOf[FilterExecTransformerBase])
assert(
plans(0)
plans.head
.asInstanceOf[HashJoinLikeExecTransformer]
.right
.isInstanceOf[ProjectExecTransformer])
@@ -585,7 +585,7 @@ class GlutenClickHouseTPCHBucketSuite
def checkResult(df: DataFrame, exceptedResult: Seq[Row]): Unit = {
// check the result
val result = df.collect()
assert(result.size == exceptedResult.size)
assert(result.length == exceptedResult.size)
val sortedRes = result.map {
s =>
Row.fromSeq(s.toSeq.map {
@@ -786,7 +786,7 @@ class GlutenClickHouseTPCHBucketSuite
|order by l_orderkey, l_returnflag, t
|limit 10
|""".stripMargin
runSql(SQL7, false)(
runSql(SQL7, noFallBack = false)(
df => {
checkResult(
df,
@@ -23,6 +23,7 @@ import org.apache.spark.{SPARK_VERSION_SHORT, SparkConf}
import org.apache.spark.sql.execution.datasources.v2.clickhouse.ClickHouseConfig

import org.apache.commons.io.FileUtils
import org.scalatest.Tag

import java.io.File

@@ -177,13 +178,23 @@ class GlutenClickHouseWholeStageTransformerSuite extends WholeStageTransformerSuite
super.beforeAll()
}

protected val rootPath = this.getClass.getResource("/").getPath
protected val basePath = rootPath + "tests-working-home"
protected val warehouse = basePath + "/spark-warehouse"
protected val metaStorePathAbsolute = basePath + "/meta"
protected val hiveMetaStoreDB = metaStorePathAbsolute + "/metastore_db"
protected val rootPath: String = this.getClass.getResource("/").getPath
protected val basePath: String = rootPath + "tests-working-home"
protected val warehouse: String = basePath + "/spark-warehouse"
protected val metaStorePathAbsolute: String = basePath + "/meta"
protected val hiveMetaStoreDB: String = metaStorePathAbsolute + "/metastore_db"

final override protected val resourcePath: String = "" // ch not need this
override protected val fileFormat: String = "parquet"

protected def testSparkVersionLE33(testName: String, testTag: Tag*)(testFun: => Any): Unit = {
if (isSparkVersionLE("3.3")) {
test(testName, testTag: _*)(testFun)
} else {
ignore(s"[$SPARK_VERSION_SHORT]-$testName", testTag: _*)(testFun)
}
}

lazy val pruningTimeValueSpark: Int = if (isSparkVersionLE("3.3")) -1 else 0
}
// scalastyle:off line.size.limit
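The new `testSparkVersionLE33` helper registers a case as a normal test on Spark 3.2/3.3 and as an ignored test (with the Spark version prefixed to its name) on 3.4+, while `pruningTimeValueSpark` captures the expected `pruningTime` metric value: -1 up to Spark 3.3, 0 afterwards. A hypothetical usage sketch; the suite name and body below are illustrative only:

```scala
// Illustrative suite: the case runs on Spark <= 3.3 and is ignored on 3.4+.
class MyCHFeatureSuite extends GlutenClickHouseWholeStageTransformerSuite {
  testSparkVersionLE33("feature that still falls back on Spark 3.5") {
    val df = spark.sql("select count(*) from range(10)")
    assert(df.collect().head.getLong(0) == 10L)
  }
}
```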
@@ -105,9 +105,9 @@ class GlutenClickhouseCountDistinctSuite extends GlutenClickHouseWholeStageTransformerSuite
val sql = s"""
select count(distinct(a,b)) , try_add(c,b) from
values (0, null,1), (0,null,2), (1, 1,4) as data(a,b,c) group by try_add(c,b)
""";
"""
val df = spark.sql(sql)
WholeStageTransformerSuite.checkFallBack(df, noFallback = false)
WholeStageTransformerSuite.checkFallBack(df, noFallback = isSparkVersionGE("3.5"))
}

test("check count distinct with filter") {