Skip to content

Commit

Permalink
[GLUTEN-5613][CH] Fix CH function SparkCheckoverflow return type not …
Browse files Browse the repository at this point in the history
…equals with spark (apache#5614)

* Fix CH function SparkCheckoverflow return type not equals with spark

* fix style
  • Loading branch information
loneylee authored May 8, 2024
1 parent 537a702 commit d9fe381
Showing 1 changed file with 172 additions and 105 deletions.
277 changes: 172 additions & 105 deletions cpp-ch/local-engine/Functions/SparkFunctionCheckDecimalOverflow.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,24 +15,24 @@
* limitations under the License.
*/
#include "SparkFunctionCheckDecimalOverflow.h"
#include <Columns/ColumnConst.h>

#include <Columns/ColumnDecimal.h>
#include <Columns/ColumnNullable.h>
#include <Columns/ColumnsNumber.h>
#include <DataTypes/DataTypeString.h>
#include <DataTypes/DataTypeNullable.h>
#include <DataTypes/DataTypesDecimal.h>
#include <Functions/FunctionFactory.h>
#include <Functions/FunctionHelpers.h>
#include <Functions/IFunction.h>


namespace DB
{
namespace ErrorCodes
{
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
extern const int ILLEGAL_COLUMN;
extern const int TYPE_MISMATCH;
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
extern const int ILLEGAL_COLUMN;
extern const int TYPE_MISMATCH;
}
}

Expand All @@ -58,137 +58,204 @@ enum class CheckExceptionMode

namespace
{
/// Checks whether the decimal value in the first argument fits into the precision and scale given by the
/// second and third arguments (both are always constant). On overflow the Throw mode raises DECIMAL_OVERFLOW, while the Null mode marks the row as NULL.
template <typename Name, CheckExceptionMode mode>
class FunctionCheckDecimalOverflow : public IFunction
{
public:
static constexpr auto name = Name::name;
static constexpr auto exception_mode = mode;
/// Converts the decimal value in the first argument to Decimal(precision, scale), where precision and scale are the
/// second and third arguments (both are always constant). On overflow the Throw mode raises DECIMAL_OVERFLOW, while the Null mode returns NULL for that row.
template <typename Name, CheckExceptionMode mode>
class FunctionCheckDecimalOverflow : public IFunction
{
public:
static constexpr auto name = Name::name;
static constexpr auto exception_mode = mode;

static FunctionPtr create(ContextPtr) { return std::make_shared<FunctionCheckDecimalOverflow>(); }
static FunctionPtr create(ContextPtr) { return std::make_shared<FunctionCheckDecimalOverflow>(); }

String getName() const override { return name; }
size_t getNumberOfArguments() const override { return 3; }
bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return false; }
bool useDefaultImplementationForConstants() const override { return true; }
ColumnNumbers getArgumentsThatAreAlwaysConstant() const override { return {1, 2}; }
String getName() const override { return name; }
size_t getNumberOfArguments() const override { return 3; }
bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return false; }
bool useDefaultImplementationForConstants() const override { return true; }
ColumnNumbers getArgumentsThatAreAlwaysConstant() const override { return {1, 2}; }

DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{
if (!isDecimal(arguments[0]) || !isInteger(arguments[1]) || !isInteger(arguments[2]))
throw Exception(
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"Illegal type {} {} {} of argument of function {}",
arguments[0]->getName(),
arguments[1]->getName(),
arguments[2]->getName(),
getName());
DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override
{
if (!isDecimal(arguments[0].type) || !isInteger(arguments[1].type) || !isInteger(arguments[2].type))
throw Exception(
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"Illegal type {} {} {} of argument of function {}",
arguments[0].type->getName(),
arguments[1].type->getName(),
arguments[2].type->getName(),
getName());

if constexpr (exception_mode == CheckExceptionMode::Null)
{
if (!arguments[0]->isNullable())
return std::make_shared<DataTypeNullable>(arguments[0]);
}
UInt32 precision = extractArgument(arguments[1]);
UInt32 scale = extractArgument(arguments[2]);

return arguments[0];
auto return_type = createDecimal<DataTypeDecimal>(precision, scale);
if constexpr (exception_mode == CheckExceptionMode::Null)
{
if (!arguments[0].type->isNullable())
return std::make_shared<DataTypeNullable>(return_type);
}

ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override
{
const auto & src_column = arguments[0];
UInt32 precision = extractArgument(arguments[1]);
UInt32 scale = extractArgument(arguments[2]);
return return_type;
}

ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override
{
const auto & src_column = arguments[0];
UInt32 precision = extractArgument(arguments[1]);
UInt32 scale = extractArgument(arguments[2]);

ColumnPtr result_column;

ColumnPtr result_column;
auto call = [&](const auto & types) -> bool
{
using Types = std::decay_t<decltype(types)>;
using FromDataType = typename Types::LeftType;
using ToDataType = typename Types::RightType;

auto call = [&](const auto & types) -> bool
if constexpr (IsDataTypeDecimal<FromDataType>)
{
using Types = std::decay_t<decltype(types)>;
using Type = typename Types::RightType;
using ColVecType = ColumnDecimal<Type>;
using FromFieldType = typename FromDataType::FieldType;
using ColVecType = ColumnDecimal<FromFieldType>;

if (const ColVecType * col_vec = checkAndGetColumn<ColVecType>(src_column.column.get()))
{
executeInternal<Type>(*col_vec, result_column, input_rows_count, precision, scale);
executeInternal<FromFieldType, ToDataType>(*col_vec, result_column, input_rows_count, precision, scale);
return true;
}
}

throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal column while execute function {}", getName());
};

if (precision <= DecimalUtils::max_precision<Decimal32>)
callOnIndexAndDataType<DataTypeDecimal<Decimal32>>(src_column.type->getTypeId(), call);
else if (precision <= DecimalUtils::max_precision<Decimal64>)
callOnIndexAndDataType<DataTypeDecimal<Decimal64>>(src_column.type->getTypeId(), call);
else if (precision <= DecimalUtils::max_precision<Decimal128>)
callOnIndexAndDataType<DataTypeDecimal<Decimal128>>(src_column.type->getTypeId(), call);
else
callOnIndexAndDataType<DataTypeDecimal<Decimal256>>(src_column.type->getTypeId(), call);

throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal column while execute function {}", getName());
};

callOnBasicType<void, false, false, true, false>(src_column.type->getTypeId(), call);
if (!result_column)
throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Wrong call for {} with {}", getName(), src_column.type->getName());
if (!result_column)
throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Wrong call for {} with {}", getName(), src_column.type->getName());

return result_column;
return result_column;
}

private:
template <typename T, typename ToDataType>
static void executeInternal(
const ColumnDecimal<T> & col_source, ColumnPtr & result_column, size_t input_rows_count, UInt32 precision, UInt32 scale_to)
{
using ToFieldType = typename ToDataType::FieldType;
using ToColumnType = typename ToDataType::ColumnType;

ColumnUInt8::MutablePtr col_null_map_to;
ColumnUInt8::Container * vec_null_map_to [[maybe_unused]] = nullptr;
auto scale_from = col_source.getScale();

if constexpr (exception_mode == CheckExceptionMode::Null)
{
col_null_map_to = ColumnUInt8::create(input_rows_count, false);
vec_null_map_to = &col_null_map_to->getData();
}

private:
template <typename T>
static void executeInternal(
const ColumnDecimal<T> & col_source, ColumnPtr & result_column, size_t input_rows_count, UInt32 precision, UInt32 scale_to)
typename ToColumnType::MutablePtr col_to = ToColumnType::create(input_rows_count, scale_to);
auto & vec_to = col_to->getData();
vec_to.resize_exact(input_rows_count);

auto & datas = col_source.getData();
for (size_t i = 0; i < input_rows_count; ++i)
{
ColumnUInt8::MutablePtr col_null_map_to;
ColumnUInt8::Container * vec_null_map_to [[maybe_unused]] = nullptr;
auto scale_from = col_source.getScale();
// bool overflow = outOfDigits<T>(datas[i], precision, scale_from, scale_to);
ToFieldType result;
bool success = convertToDecimalImpl<T, ToDataType>(datas[i], precision, scale_from, scale_to, result);

if constexpr (exception_mode == CheckExceptionMode::Null)
if (success)
vec_to[i] = static_cast<ToFieldType>(result);
else
{
col_null_map_to = ColumnUInt8::create(input_rows_count, false);
vec_null_map_to = &col_null_map_to->getData();
vec_to[i] = static_cast<ToFieldType>(0);
if constexpr (exception_mode == CheckExceptionMode::Null)
(*vec_null_map_to)[i] = static_cast<UInt8>(1);
else
throw Exception(ErrorCodes::DECIMAL_OVERFLOW, "Decimal value is overflow.");
}
}

auto & datas = col_source.getData();
for (size_t i = 0; i < input_rows_count; ++i)
{
bool overflow = outOfDigits<T>(datas[i], precision, scale_from, scale_to);
if (overflow)
{
if constexpr (exception_mode == CheckExceptionMode::Null)
(*vec_null_map_to)[i] = overflow;
else
throw Exception(ErrorCodes::DECIMAL_OVERFLOW, "Decimal value is overflow.");
}
}
if constexpr (exception_mode == CheckExceptionMode::Null)
result_column = ColumnNullable::create(std::move(col_to), std::move(col_null_map_to));
else
result_column = std::move(col_to);
}

/// Forwards a decimal field value to convertDecimalsImpl, choosing the source
/// DataTypeDecimal specialisation from the field type of the value.
/// Decimal32, Decimal64 and Decimal128 map to their own DataTypeDecimal; any other
/// field type accepted by is_decimal is handled as DataTypeDecimal<Decimal256>.
/// All arguments are passed through unchanged and the result of convertDecimalsImpl
/// is returned as-is.
template <is_decimal FromFieldType, typename ToDataType>
requires(IsDataTypeDecimal<ToDataType>)
static bool convertToDecimalImpl(
    const FromFieldType & decimal, UInt32 precision_to, UInt32 scale_from, UInt32 scale_to, typename ToDataType::FieldType & result)
{
    // Resolve the source data type once, at compile time, instead of branching per field type.
    using SourceDataType = std::conditional_t<
        std::is_same_v<FromFieldType, Decimal32>,
        DataTypeDecimal<Decimal32>,
        std::conditional_t<
            std::is_same_v<FromFieldType, Decimal64>,
            DataTypeDecimal<Decimal64>,
            std::conditional_t<std::is_same_v<FromFieldType, Decimal128>, DataTypeDecimal<Decimal128>, DataTypeDecimal<Decimal256>>>>;

    return convertDecimalsImpl<SourceDataType, ToDataType>(decimal, precision_to, scale_from, scale_to, result);
}

template <typename FromDataType, typename ToDataType>
requires(IsDataTypeDecimal<FromDataType> && IsDataTypeDecimal<ToDataType>)
static bool convertDecimalsImpl(
const typename FromDataType::FieldType & value,
UInt32 precision_to,
UInt32 scale_from,
UInt32 scale_to,
typename ToDataType::FieldType & result)
{
using FromFieldType = typename FromDataType::FieldType;
using ToFieldType = typename ToDataType::FieldType;
using MaxFieldType = std::conditional_t<(sizeof(FromFieldType) > sizeof(ToFieldType)), FromFieldType, ToFieldType>;
using MaxNativeType = typename MaxFieldType::NativeType;

typename ColumnDecimal<T>::MutablePtr col_to = ColumnDecimal<T>::create(std::move(col_source));

auto false_value = []() -> bool
{
if constexpr (exception_mode == CheckExceptionMode::Null)
result_column = ColumnNullable::create(std::move(col_to), std::move(col_null_map_to));
return false;
else
result_column = std::move(col_to);
}
throw Exception(ErrorCodes::DECIMAL_OVERFLOW, "Decimal value is overflow.");
};

template <is_decimal T>
static bool outOfDigits(T decimal, UInt32 precision_to, UInt32 scale_from, UInt32 scale_to)
MaxNativeType converted_value;
if (scale_to > scale_from)
{
using NativeT = typename T::NativeType;
converted_value = DecimalUtils::scaleMultiplier<MaxNativeType>(scale_to - scale_from);
if (common::mulOverflow(static_cast<MaxNativeType>(value.value), converted_value, converted_value))
return false_value();
}
else if (scale_to == scale_from)
converted_value = value.value;
else
converted_value = value.value / DecimalUtils::scaleMultiplier<MaxNativeType>(scale_from - scale_to);

NativeT converted_value;
if (scale_to > scale_from)
{
converted_value = DecimalUtils::scaleMultiplier<T>(scale_to - scale_from);
if (common::mulOverflow(static_cast<NativeT>(decimal.value), converted_value, converted_value))
{
if constexpr (exception_mode == CheckExceptionMode::Null)
return false;
else
throw Exception(ErrorCodes::DECIMAL_OVERFLOW, "Decimal value is overflow.");
}
}
else
converted_value = decimal.value / DecimalUtils::scaleMultiplier<NativeT>(scale_from - scale_to);
// if constexpr (sizeof(FromFieldType) > sizeof(ToFieldType))
// {
MaxNativeType pow10 = intExp10OfSize<MaxNativeType>(precision_to);
if (converted_value <= -pow10 || converted_value >= pow10)
return false_value();
// }

NativeT pow10 = intExp10OfSize<NativeT>(precision_to);
if (converted_value < 0)
return converted_value <= -pow10;
return converted_value >= pow10;
}
};
result = static_cast<typename ToFieldType::NativeType>(converted_value);
return true;
}
};

using FunctionCheckDecimalOverflowThrow = FunctionCheckDecimalOverflow<CheckDecimalOverflowSpark, CheckExceptionMode::Throw>;
using FunctionCheckDecimalOverflowOrNull = FunctionCheckDecimalOverflow<CheckDecimalOverflowSparkOrNull, CheckExceptionMode::Null>;
using FunctionCheckDecimalOverflowThrow = FunctionCheckDecimalOverflow<CheckDecimalOverflowSpark, CheckExceptionMode::Throw>;
using FunctionCheckDecimalOverflowOrNull = FunctionCheckDecimalOverflow<CheckDecimalOverflowSparkOrNull, CheckExceptionMode::Null>;
}

REGISTER_FUNCTION(CheckDecimalOverflowSpark)
Expand Down

0 comments on commit d9fe381

Please sign in to comment.