diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala index f6d18d7a2228..1dd815b6d78d 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala @@ -203,8 +203,6 @@ object CHExpressionUtil { TO_UTC_TIMESTAMP -> UtcTimestampValidator(), FROM_UTC_TIMESTAMP -> UtcTimestampValidator(), STACK -> DefaultValidator(), - TRANSFORM_KEYS -> DefaultValidator(), - TRANSFORM_VALUES -> DefaultValidator(), RAISE_ERROR -> DefaultValidator(), WIDTH_BUCKET -> DefaultValidator() ) diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala index dbe8852290aa..39b5421f5d68 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala @@ -860,4 +860,16 @@ class GlutenFunctionValidateSuite extends GlutenClickHouseWholeStageTransformerS val sql = "select cast(id % 2 = 1 as string) from range(10)" compareResultsAgainstVanillaSpark(sql, true, { _ => }) } + + test("Test transform_keys/transform_values") { + val sql = """ + |select + | transform_keys(map_from_arrays(array(id+1, id+2, id+3), + | array(1, id+2, 3)), (k, v) -> k + 1), + | transform_values(map_from_arrays(array(id+1, id+2, id+3), + | array(1, id+2, 3)), (k, v) -> v + 1) + |from range(10) + |""".stripMargin + compareResultsAgainstVanillaSpark(sql, true, { _ => }) + } } diff --git a/cpp-ch/local-engine/Parser/FunctionParser.cpp b/cpp-ch/local-engine/Parser/FunctionParser.cpp index 7e794dabec64..6ea5148ea406 100644 --- a/cpp-ch/local-engine/Parser/FunctionParser.cpp +++ b/cpp-ch/local-engine/Parser/FunctionParser.cpp @@ -181,9 +181,7 @@ FunctionParserPtr FunctionParserFactory::get(const String & name, ParserContextP { auto res = tryGet(name, ctx); if (!res) - { throw Exception(ErrorCodes::UNKNOWN_FUNCTION, "Unknown function parser {}", name); - } return res; } diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/mapHighOrderFunctions.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/mapHighOrderFunctions.cpp new file mode 100644 index 000000000000..e559980f8548 --- /dev/null +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/mapHighOrderFunctions.cpp @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "DataTypes/DataTypeMap.h" + +namespace DB::ErrorCodes +{ + extern const int SIZES_OF_COLUMNS_DOESNT_MATCH; + extern const int BAD_ARGUMENTS; +} + +namespace local_engine +{ + +template +class FunctionParserMapTransformImpl : public FunctionParser +{ +public: + static constexpr auto name = transform_keys ? "transform_keys" : "transform_values"; + String getName() const override { return name; } + + explicit FunctionParserMapTransformImpl(ParserContextPtr parser_context_) : FunctionParser(parser_context_) {} + ~FunctionParserMapTransformImpl() override = default; + + const DB::ActionsDAG::Node * + parse(const substrait::Expression_ScalarFunction & substrait_func, DB::ActionsDAG & actions_dag) const override + { + /// Parse spark transform_keys(map, func) as CH mapFromArrays(arrayMap(func, cast(map as array)), mapValues(map)) + /// Parse spark transform_values(map, func) as CH mapFromArrays(mapKeys(map), arrayMap(func, cast(map as array))) + auto parsed_args = parseFunctionArguments(substrait_func, actions_dag); + if (parsed_args.size() != 2) + throw DB::Exception(DB::ErrorCodes::SIZES_OF_COLUMNS_DOESNT_MATCH, "{} function must have three arguments", getName()); + + auto lambda_args = collectLambdaArguments(parser_context, substrait_func.arguments()[1].value().scalar_function()); + if (lambda_args.size() != 2) + throw DB::Exception( + DB::ErrorCodes::SIZES_OF_COLUMNS_DOESNT_MATCH, "The lambda function in {} must have two arguments", getName()); + + const auto * map_node = parsed_args[0]; + const auto * func_node = parsed_args[1]; + const auto & map_type = map_node->result_type; + auto array_type = checkAndGetDataType(removeNullable(map_type).get())->getNestedType(); + if (map_type->isNullable()) + array_type = std::make_shared(array_type); + const auto * array_node = ActionsDAGUtil::convertNodeTypeIfNeeded(actions_dag, map_node, array_type); + const auto * transformed_node = toFunctionNode(actions_dag, "arrayMap", {func_node, array_node}); + + const DB::ActionsDAG::Node * result_node = nullptr; + if constexpr (transform_keys) + { + const auto * nontransformed_node = toFunctionNode(actions_dag, "mapValues", {parsed_args[0]}); + result_node = toFunctionNode(actions_dag, "mapFromArrays", {transformed_node, nontransformed_node}); + } + else + { + const auto * nontransformed_node = toFunctionNode(actions_dag, "mapKeys", {parsed_args[0]}); + result_node = toFunctionNode(actions_dag, "mapFromArrays", {nontransformed_node, transformed_node}); + } + return convertNodeTypeIfNeeded(substrait_func, result_node, actions_dag); + } +}; + +using FunctionParserTransformKeys = FunctionParserMapTransformImpl; +using FunctionParserTransformValues = FunctionParserMapTransformImpl; + +static FunctionParserRegister register_transform_keys; +static FunctionParserRegister register_transform_values; +} \ No newline at end of file diff --git a/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala b/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala index 2eb5bd11ffbe..71e32bdccf7a 100644 --- a/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala +++ b/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala @@ -166,6 +166,10 @@ class ClickHouseTestSettings extends BackendTestSettings { .exclude("aggregate function - array for non-primitive type") .exclude("SPARK-14393: values generated by non-deterministic functions shouldn't change after coalesce or union") .exclude("SPARK-24734: Fix containsNull of Concat for array type") + .exclude("transform keys function - primitive data types") + .exclude("transform keys function - Invalid lambda functions and exceptions") + .exclude("transform values function - test primitive data types") + .exclude("transform values function - test empty") enableSuite[GlutenDataFrameHintSuite] enableSuite[GlutenDataFrameImplicitsSuite] enableSuite[GlutenDataFrameJoinSuite].exclude( diff --git a/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala b/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala index a7bf5d4da903..ce09d0f59580 100644 --- a/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala +++ b/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala @@ -184,6 +184,10 @@ class ClickHouseTestSettings extends BackendTestSettings { .exclude("aggregate function - array for non-primitive type") .exclude("SPARK-14393: values generated by non-deterministic functions shouldn't change after coalesce or union") .exclude("SPARK-24734: Fix containsNull of Concat for array type") + .exclude("transform keys function - primitive data types") + .exclude("transform keys function - Invalid lambda functions and exceptions") + .exclude("transform values function - test primitive data types") + .exclude("transform values function - test empty") enableSuite[GlutenDataFrameHintSuite] enableSuite[GlutenDataFrameImplicitsSuite] enableSuite[GlutenDataFrameJoinSuite].exclude( diff --git a/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala b/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala index b7e3905740fb..71e2f6375e8f 100644 --- a/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala +++ b/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala @@ -186,6 +186,10 @@ class ClickHouseTestSettings extends BackendTestSettings { .exclude("aggregate function - array for non-primitive type") .exclude("SPARK-14393: values generated by non-deterministic functions shouldn't change after coalesce or union") .exclude("SPARK-24734: Fix containsNull of Concat for array type") + .exclude("transform keys function - primitive data types") + .exclude("transform keys function - Invalid lambda functions and exceptions") + .exclude("transform values function - test primitive data types") + .exclude("transform values function - test empty") enableSuite[GlutenDataFrameHintSuite] enableSuite[GlutenDataFrameImplicitsSuite] enableSuite[GlutenDataFrameJoinSuite].exclude( diff --git a/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala b/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala index 8ce145735dc3..f08b66972c04 100644 --- a/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala +++ b/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala @@ -186,6 +186,10 @@ class ClickHouseTestSettings extends BackendTestSettings { .exclude("aggregate function - array for non-primitive type") .exclude("SPARK-14393: values generated by non-deterministic functions shouldn't change after coalesce or union") .exclude("SPARK-24734: Fix containsNull of Concat for array type") + .exclude("transform keys function - primitive data types") + .exclude("transform keys function - Invalid lambda functions and exceptions") + .exclude("transform values function - test primitive data types") + .exclude("transform values function - test empty") enableSuite[GlutenDataFrameHintSuite] enableSuite[GlutenDataFrameImplicitsSuite] enableSuite[GlutenDataFrameJoinSuite].exclude(