Skip to content

Commit

Permalink
[CH] Fix GlutenLiteralExpressionSuite and GlutenMathExpressionsSuite (#…
Browse files Browse the repository at this point in the history
…7235)

* fix failed uts

* Update CommonScalarFunctionParser.cpp

* override checkResult for ch backend
  • Loading branch information
taiyang-li authored Sep 15, 2024
1 parent 49ccdbd commit 399a91b
Show file tree
Hide file tree
Showing 7 changed files with 208 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ object CHExpressionUtil {
STACK -> DefaultValidator(),
TRANSFORM_KEYS -> DefaultValidator(),
TRANSFORM_VALUES -> DefaultValidator(),
RAISE_ERROR -> DefaultValidator()
RAISE_ERROR -> DefaultValidator(),
WIDTH_BUCKET -> DefaultValidator()
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,6 @@ REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Sign, sign, sign);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Radians, radians, radians);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Greatest, greatest, sparkGreatest);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Least, least, sparkLeast);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(ShiftLeft, shiftleft, bitShiftLeft);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(ShiftRight, shiftright, bitShiftRight);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Rand, rand, randCanonical);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Bin, bin, sparkBin);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Rint, rint, sparkRint);
Expand Down
105 changes: 105 additions & 0 deletions cpp-ch/local-engine/Parser/scalar_function_parser/shift.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <Parser/FunctionParser.h>
#include <DataTypes/IDataType.h>
#include <Common/CHUtil.h>
#include <Core/Field.h>
#include <DataTypes/DataTypeArray.h>
#include <DataTypes/DataTypeNullable.h>
#include <Functions/FunctionHelpers.h>

namespace DB
{
namespace ErrorCodes
{
extern const int BAD_ARGUMENTS;
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
}
}

namespace local_engine
{

class FunctionParserShiftBase : public FunctionParser
{
public:
explicit FunctionParserShiftBase(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) { }
~FunctionParserShiftBase() override = default;

virtual String getCHFunctionName() const = 0;

const ActionsDAG::Node * parse(
const substrait::Expression_ScalarFunction & substrait_func,
ActionsDAG & actions_dag) const override
{
/// parse spark shiftxxx(expr, n) as
/// If expr has long type -> CH bitShiftxxx(expr, pmod(n, 64))
/// Otherwise -> CH bitShiftxxx(expr, pmod(n, 32))
auto parsed_args = parseFunctionArguments(substrait_func, actions_dag);
if (parsed_args.size() != 2)
throw Exception(DB::ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires exactly two arguments", getName());


auto input_type = removeNullable(parsed_args[0]->result_type);
WhichDataType which(input_type);
const ActionsDAG::Node * base_node = nullptr;
if (which.isInt64())
{
base_node = addColumnToActionsDAG(actions_dag, std::make_shared<DataTypeInt32>(), 64);
}
else if (which.isInt32())
{
base_node = addColumnToActionsDAG(actions_dag, std::make_shared<DataTypeInt32>(), 32);
}
else
throw Exception(DB::ErrorCodes::BAD_ARGUMENTS, "First argument for function {} must be an long or integer", getName());

const auto * pmod_node = toFunctionNode(actions_dag, "pmod", {parsed_args[1], base_node});
auto ch_function_name = getCHFunctionName();
const auto * shift_node = toFunctionNode(actions_dag, ch_function_name, {parsed_args[0], pmod_node});
return convertNodeTypeIfNeeded(substrait_func, shift_node, actions_dag);
}
};

class FunctionParserShiftLeft : public FunctionParserShiftBase
{
public:
explicit FunctionParserShiftLeft(SerializedPlanParser * plan_parser_) : FunctionParserShiftBase(plan_parser_) { }
~FunctionParserShiftLeft() override = default;

static constexpr auto name = "shiftleft";
String getName() const override { return name; }

String getCHFunctionName() const override { return "bitShiftLeft"; }
};
static FunctionParserRegister<FunctionParserShiftLeft> register_shiftleft;

class FunctionParserShiftRight: public FunctionParserShiftBase
{
public:
explicit FunctionParserShiftRight(SerializedPlanParser * plan_parser_) : FunctionParserShiftBase(plan_parser_) { }
~FunctionParserShiftRight() override = default;

static constexpr auto name = "shiftright";
String getName() const override { return name; }

String getCHFunctionName() const override { return "bitShiftRight"; }
};
static FunctionParserRegister<FunctionParserShiftRight> register_shiftright;


}
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@ class FunctionParserShiftRightUnsigned : public FunctionParser
{
/// parse shiftrightunsigned(a, b) as
/// if (isInteger(a))
/// bitShiftRight(a::UInt32, b::UInt32)
/// bitShiftRight(a::UInt32, pmod(b, 32))
/// else if (isLong(a))
/// bitShiftRight(a::UInt64, b::UInt64)
/// bitShiftRight(a::UInt64, pmod(b, 32))
/// else
/// throw Exception

Expand All @@ -55,26 +55,27 @@ class FunctionParserShiftRightUnsigned : public FunctionParser

const auto * a = parsed_args[0];
const auto * b = parsed_args[1];
const auto * new_a = a;
const auto * new_b = b;

WhichDataType which(removeNullable(a->result_type));
const ActionsDAG::Node * base_node = nullptr;
const ActionsDAG::Node * unsigned_a_node = nullptr;
if (which.isInt32())
{
base_node = addColumnToActionsDAG(actions_dag, std::make_shared<DataTypeUInt32>(), 32);
const auto * uint32_type_node = addColumnToActionsDAG(actions_dag, std::make_shared<DataTypeString>(), "Nullable(UInt32)");
new_a = toFunctionNode(actions_dag, "CAST", {a, uint32_type_node});
new_b = toFunctionNode(actions_dag, "CAST", {b, uint32_type_node});
unsigned_a_node = toFunctionNode(actions_dag, "CAST", {a, uint32_type_node});
}
else if (which.isInt64())
{
base_node = addColumnToActionsDAG(actions_dag, std::make_shared<DataTypeUInt32>(), 64);
const auto * uint64_type_node = addColumnToActionsDAG(actions_dag, std::make_shared<DataTypeString>(), "Nullable(UInt64)");
new_a = toFunctionNode(actions_dag, "CAST", {a, uint64_type_node});
new_b = toFunctionNode(actions_dag, "CAST", {b, uint64_type_node});
unsigned_a_node = toFunctionNode(actions_dag, "CAST", {a, uint64_type_node});
}
else
throw Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Function {} requires integer or long as first argument", getName());

const auto * result = toFunctionNode(actions_dag, "bitShiftRight", {new_a, new_b});
const auto * pmod_node = toFunctionNode(actions_dag, "pmod", {b, base_node});
const auto * result = toFunctionNode(actions_dag, "bitShiftRight", {unsigned_a_node, pmod_node});
return convertNodeTypeIfNeeded(substrait_func, result, actions_dag);
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -788,32 +788,12 @@ class ClickHouseTestSettings extends BackendTestSettings {
.exclude("SPARK-35728: Check multiply/divide of day-time intervals of any fields by numeric")
.exclude("SPARK-35778: Check multiply/divide of year-month intervals of any fields by numeric")
enableSuite[GlutenLiteralExpressionSuite]
.exclude("null")
.exclude("default")
.exclude("decimal")
.exclude("array")
.exclude("seq")
.exclude("map")
.exclude("struct")
.exclude("SPARK-35664: construct literals from java.time.LocalDateTime")
.exclude("SPARK-34605: construct literals from java.time.Duration")
.exclude("SPARK-34605: construct literals from arrays of java.time.Duration")
.exclude("SPARK-34615: construct literals from java.time.Period")
.exclude("SPARK-34615: construct literals from arrays of java.time.Period")
.exclude("SPARK-35871: Literal.create(value, dataType) should support fields")
.exclude("SPARK-37967: Literal.create support ObjectType")
enableSuite[GlutenMathExpressionsSuite]
.exclude("tanh")
.exclude("unhex")
.exclude("atan2")
.exclude("round/bround/floor/ceil")
.exclude("SPARK-36922: Support ANSI intervals for SIGN/SIGNUM")
.exclude("SPARK-35926: Support YearMonthIntervalType in width-bucket function")
.exclude("SPARK-35925: Support DayTimeIntervalType in width-bucket function")
.exclude("SPARK-37388: width_bucket")
.exclude("shift left")
.exclude("shift right")
.exclude("shift right unsigned")
.exclude("unhex") // https://github.com/apache/incubator-gluten/issues/7232
.exclude("round/bround/floor/ceil") // https://github.com/apache/incubator-gluten/issues/7233
.exclude("atan2") // https://github.com/apache/incubator-gluten/issues/7233
enableSuite[GlutenMiscExpressionsSuite]
enableSuite[GlutenNondeterministicSuite]
.exclude("MonotonicallyIncreasingID")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,41 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.GlutenTestsTrait
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval

class GlutenLiteralExpressionSuite extends LiteralExpressionSuite with GlutenTestsTrait {}
import java.nio.charset.StandardCharsets
import java.time.{Instant, LocalDate}

class GlutenLiteralExpressionSuite extends LiteralExpressionSuite with GlutenTestsTrait {
testGluten("default") {
checkEvaluation(Literal.default(BooleanType), false)
checkEvaluation(Literal.default(ByteType), 0.toByte)
checkEvaluation(Literal.default(ShortType), 0.toShort)
checkEvaluation(Literal.default(IntegerType), 0)
checkEvaluation(Literal.default(LongType), 0L)
checkEvaluation(Literal.default(FloatType), 0.0f)
checkEvaluation(Literal.default(DoubleType), 0.0)
checkEvaluation(Literal.default(StringType), "")
checkEvaluation(Literal.default(BinaryType), "".getBytes(StandardCharsets.UTF_8))
checkEvaluation(Literal.default(DecimalType.USER_DEFAULT), Decimal(0))
checkEvaluation(Literal.default(DecimalType.SYSTEM_DEFAULT), Decimal(0))
withSQLConf(SQLConf.DATETIME_JAVA8API_ENABLED.key -> "false") {
checkEvaluation(Literal.default(DateType), DateTimeUtils.toJavaDate(0))
checkEvaluation(Literal.default(TimestampType), DateTimeUtils.toJavaTimestamp(0L))
}
withSQLConf(SQLConf.DATETIME_JAVA8API_ENABLED.key -> "true") {
checkEvaluation(Literal.default(DateType), LocalDate.ofEpochDay(0))
checkEvaluation(Literal.default(TimestampType), Instant.ofEpochSecond(0))
}
checkEvaluation(Literal.default(CalendarIntervalType), new CalendarInterval(0, 0, 0L))
checkEvaluation(Literal.default(YearMonthIntervalType()), 0)
checkEvaluation(Literal.default(DayTimeIntervalType()), 0L)
checkEvaluation(Literal.default(ArrayType(StringType)), Array())
checkEvaluation(Literal.default(MapType(IntegerType, StringType)), Map())
checkEvaluation(Literal.default(StructType(StructField("a", StringType) :: Nil)), Row(""))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,44 @@ package org.apache.spark.sql.catalyst.expressions

import org.apache.gluten.utils.BackendTestUtils

import org.apache.spark.sql.GlutenQueryTestUtil.isNaNOrInf
import org.apache.spark.sql.GlutenTestsTrait
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.types._

import org.apache.commons.math3.util.Precision

import java.nio.charset.StandardCharsets

class GlutenMathExpressionsSuite extends MathExpressionsSuite with GlutenTestsTrait {
override protected def checkResult(
result: Any,
expected: Any,
exprDataType: DataType,
exprNullable: Boolean): Boolean = {
if (BackendTestUtils.isVeloxBackendLoaded()) {
super.checkResult(result, expected, exprDataType, exprNullable)
} else {
// The result is null for a non-nullable expression
assert(result != null || exprNullable, "exprNullable should be true if result is null")
(result, expected) match {
case (result: Double, expected: Double) =>
if (
(isNaNOrInf(result) || isNaNOrInf(expected))
|| (result == -0.0) || (expected == -0.0)
) {
java.lang.Double.doubleToRawLongBits(result) ==
java.lang.Double.doubleToRawLongBits(expected)
} else {
Precision.equalsWithRelativeTolerance(result, expected, 0.00001d) ||
Precision.equals(result, expected, 0.00001d)
}
case _ =>
super.checkResult(result, expected, exprDataType, exprNullable)
}
}
}

testGluten("round/bround/floor/ceil") {
val scales = -6 to 6
val doublePi: Double = math.Pi
Expand Down Expand Up @@ -284,4 +317,22 @@ class GlutenMathExpressionsSuite extends MathExpressionsSuite with GlutenTestsTr
checkEvaluation(checkDataTypeAndCast(RoundCeil(Literal(3.1411), Literal(-3))), Decimal(1000))
checkEvaluation(checkDataTypeAndCast(RoundCeil(Literal(135.135), Literal(-2))), Decimal(200))
}

testGluten("unhex") {
checkEvaluation(Unhex(Literal.create(null, StringType)), null)
checkEvaluation(Unhex(Literal("737472696E67")), "string".getBytes(StandardCharsets.UTF_8))
checkEvaluation(Unhex(Literal("")), new Array[Byte](0))
checkEvaluation(Unhex(Literal("F")), Array[Byte](15))
checkEvaluation(Unhex(Literal("ff")), Array[Byte](-1))

// checkEvaluation(Unhex(Literal("GG")), null)
checkEvaluation(Unhex(Literal("123")), Array[Byte](1, 35))
checkEvaluation(Unhex(Literal("12345")), Array[Byte](1, 35, 69))
// scalastyle:off
// Turn off scala style for non-ascii chars
checkEvaluation(Unhex(Literal("E4B889E9878DE79A84")), "三重的".getBytes(StandardCharsets.UTF_8))
// checkEvaluation(Unhex(Literal("三重的")), null)
// scalastyle:on
checkConsistencyBetweenInterpretedAndCodegen(Unhex, StringType)
}
}

0 comments on commit 399a91b

Please sign in to comment.