diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala index 865d3fa40c0f..48f44c6908e7 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala @@ -3111,6 +3111,65 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends GlutenClickHouseTPCHAbstr spark.sql("drop table if exists test_7647") } + test("GLUTEN-7905 get topk of window by aggregate") { + def checkWindowGroupLimit(df: DataFrame): Unit = { + val expands = collectWithSubqueries(df.queryExecution.executedPlan) { + case e: ExpandExecTransformer if (e.child.isInstanceOf[WindowGroupLimitExecTransformer]) => + e + } + assert(expands.size == 1) + } + spark.sql("create table test_win_top (a string, b int, c int) using parquet") + spark.sql(""" + |insert into test_win_top values + |('a', 3, 3), ('a', 1, 5), ('a', 2, 2), ('a', null, null), ('a', null, 1), + |('b', 1, 1), ('b', 2, 1), + |('c', 2, 3) + |""".stripMargin) + compareResultsAgainstVanillaSpark( + """ + |select a, b, c, row_number() over (partition by a order by b desc nulls first) as r + |from test_win_top + |""".stripMargin, + true, + checkWindowGroupLimit + ) + compareResultsAgainstVanillaSpark( + """ + |select a, b, c, row_number() over (partition by a order by b desc nulls last) as r + |from test_win_top + |""".stripMargin, + true, + checkWindowGroupLimit + ) + compareResultsAgainstVanillaSpark( + """ + |select a, b, c, row_number() over (partition by a order by b asc nulls first) as r + |from test_win_top + |""".stripMargin, + true, + checkWindowGroupLimit + ) + compareResultsAgainstVanillaSpark( + """ + |select a, b, c, row_number() over (partition by a order by b asc nulls last) as r + |from test_win_top + |""".stripMargin, + true, + checkWindowGroupLimit + ) + compareResultsAgainstVanillaSpark( + """ + |select a, b, c, row_number() over (partition by a order by b , c) as r + |from test_win_top + |""".stripMargin, + true, + checkWindowGroupLimit + ) + spark.sql("drop table if exists test_win_top") + + } + test("GLUTEN-7759: Fix bug of agg pre-project push down") { val table_create_sql = "create table test_tbl_7759(id bigint, name string, day string) using parquet" diff --git a/cpp-ch/local-engine/AggregateFunctions/GroupLimitFunctions.cpp b/cpp-ch/local-engine/AggregateFunctions/GroupLimitFunctions.cpp index 93bc9fb05223..8ca64b20b701 100644 --- a/cpp-ch/local-engine/AggregateFunctions/GroupLimitFunctions.cpp +++ b/cpp-ch/local-engine/AggregateFunctions/GroupLimitFunctions.cpp @@ -14,6 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include #include #include #include @@ -29,80 +30,74 @@ #include #include #include +#include +#include +#include #include #include #include #include #include -#include #include #include +namespace DB::ErrorCodes +{ +extern const int LOGICAL_ERROR; +extern const int BAD_ARGUMENTS; +} + namespace local_engine { -template -class Sorter +struct SortOrderField { -public: + size_t pos = 0; + Int8 direction = 0; + Int8 nulls_direction = 0; }; +using SortOrderFields = std::vector; struct RowNumGroupArraySortedData { public: - using Data = std::pair; + using Data = DB::Tuple; std::vector values; - // return a < b - static bool compare(const Data & lhs, const Data & rhs, const std::vector & directions, const std::vector & nulls_first) + static bool compare(const Data & lhs, const Data & rhs, const SortOrderFields & sort_orders) { - const auto & a = lhs.second; - const auto & b = rhs.second; - for (size_t i = 0; i < a.size(); ++i) + for (const auto & sort_order : sort_orders) { - bool a_is_null = a[i].isNull(); - bool b_is_null = b[i].isNull(); - if (a_is_null && b_is_null) - { + const auto & pos = sort_order.pos; + const auto & asc = sort_order.direction; + const auto & nulls_first = sort_order.nulls_direction; + bool l_is_null = lhs[pos].isNull(); + bool r_is_null = rhs[pos].isNull(); + if (l_is_null && r_is_null) continue; - } - else if (nulls_first[i]) - { - if (a_is_null) - return false; - else if (b_is_null) - return true; - else if (a[i] < b[i]) - return directions[i]; - else if (a[i] > b[i]) - return !directions[i]; - } - else - { - if (a_is_null) - return true; - else if (b_is_null) - return false; - else if (a[i] < b[i]) - return directions[i]; - else if (a[i] > b[i]) - return !directions[i]; - } + else if (l_is_null) + return nulls_first; + else if (r_is_null) + return !nulls_first; + else if (lhs[pos] < rhs[pos]) + return asc; + else if (lhs[pos] > rhs[pos]) + return !asc; } return false; } - ALWAYS_INLINE void heapReplaceTop(const std::vector & directions, const std::vector & nulls_first) + ALWAYS_INLINE void heapReplaceTop(const SortOrderFields & sort_orders) { size_t size = values.size(); if (size < 2) return; size_t child_index = 1; - if (size > 2 && compare(values[1], values[2], directions, nulls_first)) + if (size > 2 && compare(values[1], values[2], sort_orders)) ++child_index; - if (compare(values[child_index], values[0], directions, nulls_first)) + if (compare(values[child_index], values[0], sort_orders)) return; size_t current_index = 0; @@ -117,46 +112,41 @@ struct RowNumGroupArraySortedData if (child_index >= size) break; - if ((child_index + 1) < size && compare(values[child_index], values[child_index + 1], directions, nulls_first)) + if ((child_index + 1) < size && compare(values[child_index], values[child_index + 1], sort_orders)) ++child_index; - } while (!compare(values[child_index], current, directions, nulls_first)); + } while (!compare(values[child_index], current, sort_orders)); values[current_index] = current; } - ALWAYS_INLINE void - addElement(Data data, const std::vector & directions, const std::vector & nulls_first, size_t max_elements) + ALWAYS_INLINE void addElement(Data data, const SortOrderFields & sort_orders, size_t max_elements) { if (values.size() >= max_elements) { - if (!compare(data, values[0], directions, nulls_first)) + if (!compare(data, values[0], sort_orders)) return; values[0] = std::move(data); - heapReplaceTop(directions, nulls_first); + heapReplaceTop(sort_orders); return; } values.push_back(std::move(data)); - auto cmp = [&directions, &nulls_first](const Data & a, const Data & b) { return compare(a, b, directions, nulls_first); }; + auto cmp = [&sort_orders](const Data & a, const Data & b) { return compare(a, b, sort_orders); }; std::push_heap(values.begin(), values.end(), cmp); } - ALWAYS_INLINE void sortAndLimit(size_t max_elements, const std::vector & directions, const std::vector & nulls_first) + ALWAYS_INLINE void sortAndLimit(size_t max_elements, const SortOrderFields & sort_orders) { - ::sort( - values.begin(), - values.end(), - [&directions, &nulls_first](const Data & a, const Data & b) { return compare(a, b, directions, nulls_first); }); + ::sort(values.begin(), values.end(), [&sort_orders](const Data & a, const Data & b) { return compare(a, b, sort_orders); }); if (values.size() > max_elements) values.resize(max_elements); } - ALWAYS_INLINE void - insertResultInto(DB::IColumn & to, size_t max_elements, const std::vector & directions, const std::vector & nulls_first) + ALWAYS_INLINE void insertResultInto(DB::IColumn & to, size_t max_elements, const SortOrderFields & sort_orders) { auto & result_array = assert_cast(to); auto & result_array_offsets = result_array.getOffsets(); - sortAndLimit(max_elements, directions, nulls_first); + sortAndLimit(max_elements, sort_orders); result_array_offsets.push_back(result_array_offsets.back() + values.size()); @@ -165,14 +155,14 @@ struct RowNumGroupArraySortedData auto & result_array_data = result_array.getData(); for (int i = 0, sz = static_cast(values.size()); i < sz; ++i) { - auto & value = values[i].first; + auto & value = values[i]; value.push_back(i + 1); result_array_data.insert(value); } } }; -static DB::DataTypePtr getReultDataType(DB::DataTypePtr data_type) +static DB::DataTypePtr getRowNumReultDataType(DB::DataTypePtr data_type) { const auto * tuple_type = typeid_cast(data_type.get()); if (!tuple_type) @@ -184,71 +174,43 @@ static DB::DataTypePtr getReultDataType(DB::DataTypePtr data_type) auto nested_tuple_type = std::make_shared(element_types, element_names); return std::make_shared(nested_tuple_type); } + +// usage: rowNumGroupArraySorted(1, "a asc nulls first, b desc nulls last")(tuple(a,b)) class RowNumGroupArraySorted final : public DB::IAggregateFunctionDataHelper { public: - explicit RowNumGroupArraySorted(DB::DataTypePtr value_data_type, DB::DataTypePtr order_data_type, const DB::Array & parameters_) + explicit RowNumGroupArraySorted(DB::DataTypePtr data_type, const DB::Array & parameters_) : DB::IAggregateFunctionDataHelper( - {value_data_type, order_data_type}, parameters_, getReultDataType(value_data_type)) + {data_type}, parameters_, getRowNumReultDataType(data_type)) { - const auto * order_tuple = typeid_cast(order_data_type.get()); - if (!order_tuple) - throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Tuple type is expected, but got: {}", order_data_type->getName()); - if (parameters_.size() != order_tuple->getElements().size() + 1) - throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Number of sort directions should be equal to number of order columns"); + if (parameters_.size() != 2) + throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "{} needs two parameters: limit and order clause", getName()); + const auto * tuple_type = typeid_cast(data_type.get()); + if (!tuple_type) + throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Tuple type is expected, but got: {}", data_type->getName()); limit = parameters_[0].safeGet(); - for (size_t i = 1; i < parameters_.size(); ++i) - { - auto direction = magic_enum::enum_cast(parameters_[i].safeGet()); - if (!direction.has_value()) - throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Unknow direction: {}", parameters_[i].safeGet()); - if (direction.value() == SortOrder::ASC_NULLS_FIRST) - { - directions.push_back(1); - nulls_first.push_back(1); - } - else if (direction.value() == SortOrder::ASC_NULLS_LAST) - { - directions.push_back(1); - nulls_first.push_back(0); - } - else if (direction.value() == SortOrder::DESC_NULLS_FIRST) - { - directions.push_back(0); - nulls_first.push_back(1); - } - else if (direction.value() == SortOrder::DESC_NULLS_LAST) - { - directions.push_back(0); - nulls_first.push_back(0); - } - else - throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Unknown sort direction: {}", magic_enum::enum_name(direction.value())); - } - DB::DataTypeTuple serialization_data_type({value_data_type, order_data_type}); - serialization = serialization_data_type.getDefaultSerialization(); + String order_by_clause = parameters_[1].safeGet(); + sort_order_fields = parseSortOrderFields(order_by_clause); + + serialization = data_type->getDefaultSerialization(); } String getName() const override { return "rowNumGroupArraySorted"; } void add(DB::AggregateDataPtr __restrict place, const DB::IColumn ** columns, size_t row_num, DB::Arena * /*arena*/) const override { - LOG_ERROR(getLogger("RowNumGroupArraySorted"), "xxx add"); auto & data = this->data(place); DB::Tuple data_tuple = (*columns[0])[row_num].safeGet(); - DB::Tuple order_tuple = (*columns[1])[row_num].safeGet(); - this->data(place).addElement( - std::make_pair(std::move(data_tuple), std::move(order_tuple)), directions, nulls_first, limit); + this->data(place).addElement(std::move(data_tuple), sort_order_fields, limit); } void merge(DB::AggregateDataPtr __restrict place, DB::ConstAggregateDataPtr rhs, DB::Arena * /*arena*/) const override { - LOG_ERROR(getLogger("RowNumGroupArraySorted"), "xxx merge"); auto & rhs_values = this->data(rhs).values; for (auto & rhs_element : rhs_values) - this->data(place).addElement(rhs_element, directions, nulls_first, limit); + this->data(place).addElement(rhs_element, sort_order_fields, limit); } void serialize(DB::ConstAggregateDataPtr __restrict place, DB::WriteBuffer & buf, std::optional /* version */) const override @@ -257,11 +219,8 @@ class RowNumGroupArraySorted final : public DB::IAggregateFunctionDataHelperserializeBinary(data, buf, {}); - } + for (const auto & value : values) + serialization->serializeBinary(value, buf, {}); } void deserialize( @@ -276,40 +235,62 @@ class RowNumGroupArraySorted final : public DB::IAggregateFunctionDataHelperdeserializeBinary(data, buf, {}); - DB::Tuple tuple_data = data.safeGet(); - DB::Tuple a = tuple_data[0].safeGet(); - DB::Tuple b = tuple_data[1].safeGet(); - auto value = std::make_pair(std::move(a), std::move(b)); - values.emplace_back(value); + values.emplace_back(data.safeGet()); } } void insertResultInto(DB::AggregateDataPtr __restrict place, DB::IColumn & to, DB::Arena * /*arena*/) const override { - LOG_ERROR(getLogger("RowNumGroupArraySorted"), "xxx insertResultInto"); - this->data(place).insertResultInto(to, limit, directions, nulls_first); + this->data(place).insertResultInto(to, limit, sort_order_fields); } bool allocatesMemoryInArena() const override { return true; } private: size_t limit = 0; - std::vector directions; - std::vector nulls_first; + SortOrderFields sort_order_fields; DB::SerializationPtr serialization; + + SortOrderFields parseSortOrderFields(const String & order_by_clause) const + { + DB::ParserOrderByExpressionList order_by_parser; + auto order_by_ast = DB::parseQuery(order_by_parser, order_by_clause, 1000, 1000, 1000); + SortOrderFields fields; + const auto expression_list_ast = assert_cast(order_by_ast.get()); + const auto & tuple_element_names = assert_cast(argument_types[0].get())->getElementNames(); + for (const auto & child : expression_list_ast->children) + { + const auto * order_by_element_ast = assert_cast(child.get()); + const auto * ident_ast = assert_cast(order_by_element_ast->children[0].get()); + const auto & ident_name = ident_ast->shortName(); + + + SortOrderField field; + field.direction = order_by_element_ast->direction == 1; + field.nulls_direction + = field.direction ? order_by_element_ast->nulls_direction == -1 : order_by_element_ast->nulls_direction == 1; + + auto name_pos = std::find(tuple_element_names.begin(), tuple_element_names.end(), ident_name); + if (name_pos == tuple_element_names.end()) + { + throw DB::Exception( + DB::ErrorCodes::BAD_ARGUMENTS, "Not found column {} in tuple {}", ident_name, argument_types[0]->getName()); + } + field.pos = std::distance(tuple_element_names.begin(), name_pos); + + fields.push_back(field); + } + return fields; + } }; DB::AggregateFunctionPtr createAggregateFunctionRowNumGroupArray( const std::string & name, const DB::DataTypes & argument_types, const DB::Array & parameters, const DB::Settings *) { - std::string query = "x desc nulls first, y asc"; - DB::ParserOrderByExpressionList order_parser; - auto ast = DB::parseQuery(order_parser, query, 100, 100, 100); - LOG_DEBUG(getLogger("RowNumGroupArraySorted"), "xxx test ast\n{}", DB::queryToString(ast)); - for (const auto & data_type : argument_types) - LOG_ERROR(getLogger("RowNumGroupArraySorte"), "xxx arg type: {}", data_type->getName()); - return std::make_shared(argument_types[0], argument_types[1], parameters); + if (argument_types.size() != 1 || !typeid_cast(argument_types[0].get())) + throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, " {} Nees only one tuple argument", name); + return std::make_shared(argument_types[0], parameters); } void registerAggregateFunctionRowNumGroup(DB::AggregateFunctionFactory & factory) diff --git a/cpp-ch/local-engine/Parser/RelParsers/GroupLimitRelParser.cpp b/cpp-ch/local-engine/Parser/RelParsers/GroupLimitRelParser.cpp index 04f1776416e0..8637d987228b 100644 --- a/cpp-ch/local-engine/Parser/RelParsers/GroupLimitRelParser.cpp +++ b/cpp-ch/local-engine/Parser/RelParsers/GroupLimitRelParser.cpp @@ -16,12 +16,14 @@ */ #include "GroupLimitRelParser.h" +#include #include #include #include #include #include #include +#include #include #include #include @@ -34,7 +36,6 @@ #include #include #include -#include namespace DB::ErrorCodes { @@ -238,22 +239,6 @@ void AggregateGroupLimitRelParser::prePrejectionForAggregateArguments() "aggregate_data_tuple", aggregate_tuple_column_name); - DB::DataTypes order_tuple_types; - Strings order_tuple_names; - DB::ActionsDAG::NodeRawConstPtrs order_tuple_nodes; - for (const auto & sort_field : win_rel_def->sorts()) - { - if (sort_field.expr().has_selection()) - { - auto col_pos = sort_field.expr().selection().direct_reference().struct_field().field(); - const auto & col = input_header.getByPosition(col_pos); - order_tuple_types.push_back(col.type); - order_tuple_names.push_back(col.name); - order_tuple_nodes.push_back(projection_actions->getInputs()[col_pos]); - } - } - build_tuple(order_tuple_types, order_tuple_names, order_tuple_nodes, "order_tuple", order_tuple_column_name); - projection_actions->removeUnusedActions(required_column_names); LOG_DEBUG( getLogger("AggregateGroupLimitRelParser"), @@ -267,33 +252,48 @@ void AggregateGroupLimitRelParser::prePrejectionForAggregateArguments() } -DB::Array AggregateGroupLimitRelParser::parseSortDirections(const google::protobuf::RepeatedPtrField & sort_fields) +String AggregateGroupLimitRelParser::parseSortDirections(const google::protobuf::RepeatedPtrField & sort_fields) { DB::Array directions; + static const std::unordered_map order_directions + = {{1, " asc nulls first"}, {2, " asc nulls last"}, {3, " desc nulls first"}, {4, " desc nulls last"}}; + size_t n = 0; + DB::WriteBufferFromOwnString ostr; for (const auto & sort_field : sort_fields) { - auto sort_order = magic_enum::enum_cast(sort_field.direction()); - if (!sort_order.has_value()) + auto it = order_directions.find(sort_field.direction()); + if (it == order_directions.end()) throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Unknow sort direction: {}", sort_field.direction()); - directions.emplace_back(std::string(magic_enum::enum_name(sort_order.value()))); + if (!sort_field.expr().has_selection()) + { + throw DB::Exception( + DB::ErrorCodes::BAD_ARGUMENTS, "Sort field must be a column reference. but got {}", sort_field.DebugString()); + } + auto ref = sort_field.expr().selection().direct_reference().struct_field().field(); + const auto & col_name = input_header.getByPosition(ref).name; + if (n) + ostr << String(","); + // the col_name may contain '#' which can may ch fail to parse. + ostr << "`" << col_name << "`" << it->second; + n += 1; } - return directions; + LOG_DEBUG(getLogger("AggregateGroupLimitRelParser"), "Order by clasue: {}", ostr.str()); + return ostr.str(); } DB::AggregateDescription AggregateGroupLimitRelParser::buildAggregateDescription() { DB::AggregateDescription agg_desc; agg_desc.column_name = aggregate_tuple_column_name; - agg_desc.argument_names = {aggregate_tuple_column_name, order_tuple_column_name}; + agg_desc.argument_names = {aggregate_tuple_column_name}; DB::Array parameters; parameters.push_back(static_cast(limit)); auto sort_directions = parseSortDirections(win_rel_def->sorts()); - parameters.insert(parameters.end(), sort_directions.begin(), sort_directions.end()); + parameters.push_back(sort_directions); auto header = current_plan->getCurrentHeader(); DB::DataTypes arg_types; arg_types.push_back(header.getByName(aggregate_tuple_column_name).type); - arg_types.push_back(header.getByName(order_tuple_column_name).type); DB::AggregateFunctionProperties properties; agg_desc.function = getAggregateFunction(aggregate_function_name, arg_types, properties, parameters); diff --git a/cpp-ch/local-engine/Parser/RelParsers/GroupLimitRelParser.h b/cpp-ch/local-engine/Parser/RelParsers/GroupLimitRelParser.h index 5561fe448141..b8c71c819b91 100644 --- a/cpp-ch/local-engine/Parser/RelParsers/GroupLimitRelParser.h +++ b/cpp-ch/local-engine/Parser/RelParsers/GroupLimitRelParser.h @@ -75,14 +75,13 @@ class AggregateGroupLimitRelParser : public RelParser // DB::Block output_header; DB::Names aggregate_grouping_keys; String aggregate_tuple_column_name; - String order_tuple_column_name; String getAggregateFunctionName(const String & window_function_name); void prePrejectionForAggregateArguments(); void addGroupLmitAggregationStep(); - DB::Array parseSortDirections(const google::protobuf::RepeatedPtrField & sort_fields); + String parseSortDirections(const google::protobuf::RepeatedPtrField & sort_fields); DB::AggregateDescription buildAggregateDescription(); void postProjectionForExplodingArrays();