From 9b5ef10f5a95cd9761ef8e84598c7d21f969b09d Mon Sep 17 00:00:00 2001 From: Jin Hai Date: Wed, 20 Nov 2024 11:46:37 +0800 Subject: [PATCH] Fix sort expression (#2271) ### What problem does this PR solve? _Briefly describe what this PR aims to solve. Include background context that will help reviewers understand the purpose of the PR._ ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) - [x] Python SDK impacted, Need to update PyPI --------- Signed-off-by: jinhai Signed-off-by: Jin Hai --- .../local_infinity/query_builder.py | 16 +-- .../infinity_embedded/local_infinity/table.py | 8 +- .../infinity/remote_thrift/query_builder.py | 108 ++++++++++-------- .../infinity/remote_thrift/table.py | 8 +- 4 files changed, 70 insertions(+), 70 deletions(-) diff --git a/python/infinity_embedded/local_infinity/query_builder.py b/python/infinity_embedded/local_infinity/query_builder.py index 2a21c2b9bf..6f12c0a6b2 100644 --- a/python/infinity_embedded/local_infinity/query_builder.py +++ b/python/infinity_embedded/local_infinity/query_builder.py @@ -10,7 +10,7 @@ from pyarrow import Table from sqlglot import condition, maybe_parse -from infinity_embedded.common import VEC, SparseVector, InfinityException +from infinity_embedded.common import VEC, SparseVector, InfinityException, SortType from infinity_embedded.embedded_infinity_ext import * from infinity_embedded.local_infinity.types import logic_type_to_dtype, make_match_tensor_expr from infinity_embedded.local_infinity.utils import traverse_conditions, parse_expr @@ -477,7 +477,7 @@ def highlight(self, columns: Optional[list]) -> InfinityLocalQueryBuilder: self._highlight = highlight_list return self - def sort(self, order_by_expr_list: Optional[List[list[str, bool]]]) -> InfinityLocalQueryBuilder: + def sort(self, order_by_expr_list: Optional[List[list[str, SortType]]]) -> InfinityLocalQueryBuilder: sort_list: List[WrapOrderByExpr] = [] for order_by_expr in order_by_expr_list: if isinstance(order_by_expr[0], str): @@ -491,7 +491,7 @@ def sort(self, order_by_expr_list: Optional[List[list[str, bool]]]) -> InfinityL parsed_expr = WrapParsedExpr(ParsedExprType.kColumn) parsed_expr.column_expr = column_expr - order_by_expr = WrapOrderByExpr(parsed_expr, order_by_expr[1]) + order_by_expr = WrapOrderByExpr(parsed_expr, order_by_expr[1] == SortType.Asc) sort_list.append(order_by_expr) case "_row_id": func_expr = WrapFunctionExpr() @@ -502,7 +502,7 @@ def sort(self, order_by_expr_list: Optional[List[list[str, bool]]]) -> InfinityL parsed_expr = WrapParsedExpr(expr_type) parsed_expr.function_expr = func_expr - order_by_expr = WrapOrderByExpr(parsed_expr, order_by_expr[1]) + order_by_expr = WrapOrderByExpr(parsed_expr, order_by_expr[1] == SortType.Asc) sort_list.append(order_by_expr) case "_score": func_expr = WrapFunctionExpr() @@ -513,7 +513,7 @@ def sort(self, order_by_expr_list: Optional[List[list[str, bool]]]) -> InfinityL parsed_expr = WrapParsedExpr(expr_type) parsed_expr.function_expr = func_expr - order_by_expr = WrapOrderByExpr(parsed_expr, order_by_expr[1]) + order_by_expr = WrapOrderByExpr(parsed_expr, order_by_expr[1] == SortType.Asc) sort_list.append(order_by_expr) case "_similarity": func_expr = WrapFunctionExpr() @@ -524,7 +524,7 @@ def sort(self, order_by_expr_list: Optional[List[list[str, bool]]]) -> InfinityL parsed_expr = WrapParsedExpr(expr_type) parsed_expr.function_expr = func_expr - order_by_expr = WrapOrderByExpr(parsed_expr, order_by_expr[1]) + order_by_expr = WrapOrderByExpr(parsed_expr, order_by_expr[1] == SortType.Asc) sort_list.append(order_by_expr) case "_distance": func_expr = WrapFunctionExpr() @@ -535,11 +535,11 @@ def sort(self, order_by_expr_list: Optional[List[list[str, bool]]]) -> InfinityL parsed_expr = WrapParsedExpr(expr_type) parsed_expr.function_expr = func_expr - order_by_expr = WrapOrderByExpr(parsed_expr, order_by_expr[1]) + order_by_expr = WrapOrderByExpr(parsed_expr, order_by_expr[1] == SortType.Asc) sort_list.append(order_by_expr) case _: parsed_expr = parse_expr(maybe_parse(order_by_expr[0])) - order_by_expr = WrapOrderByExpr(parsed_expr, order_by_expr[1]) + order_by_expr = WrapOrderByExpr(parsed_expr, order_by_expr[1] == SortType.Asc) sort_list.append(order_by_expr) self._sort = sort_list diff --git a/python/infinity_embedded/local_infinity/table.py b/python/infinity_embedded/local_infinity/table.py index 114662942f..b6a1ac35ea 100644 --- a/python/infinity_embedded/local_infinity/table.py +++ b/python/infinity_embedded/local_infinity/table.py @@ -381,14 +381,10 @@ def offset(self, offset: Optional[int]): def sort(self, order_by_expr_list: Optional[List[list[str, SortType]]]): for order_by_expr in order_by_expr_list: if len(order_by_expr) != 2: - raise InfinityException(ErrorCode.INVALID_PARAMETER, + raise InfinityException(ErrorCode.INVALID_PARAMETER_VALUE, "order_by_expr_list must be a list of [column_name, sort_type]") if order_by_expr[1] not in [SortType.Asc, SortType.Desc]: - raise InfinityException(ErrorCode.INVALID_PARAMETER, "sort_type must be SortType.Asc or SortType.Desc") - if order_by_expr[1] == SortType.Asc: - order_by_expr[1] = True - else: - order_by_expr[1] = False + raise InfinityException(ErrorCode.INVALID_PARAMETER_VALUE, "sort_type must be SortType.Asc or SortType.Desc") self.query_builder.sort(order_by_expr_list) return self diff --git a/python/infinity_sdk/infinity/remote_thrift/query_builder.py b/python/infinity_sdk/infinity/remote_thrift/query_builder.py index ffdebb91f2..cb7dbb8ad0 100644 --- a/python/infinity_sdk/infinity/remote_thrift/query_builder.py +++ b/python/infinity_sdk/infinity/remote_thrift/query_builder.py @@ -24,7 +24,7 @@ from pyarrow import Table from sqlglot import condition, maybe_parse -from infinity.common import VEC, SparseVector, InfinityException +from infinity.common import VEC, SparseVector, InfinityException, SortType from infinity.errors import ErrorCode from infinity.remote_thrift.infinity_thrift_rpc.ttypes import * from infinity.remote_thrift.types import ( @@ -39,15 +39,15 @@ class Query(ABC): def __init__( - self, - columns: Optional[List[ParsedExpr]], - highlight: Optional[List[ParsedExpr]], - search: Optional[SearchExpr], - filter: Optional[ParsedExpr], - groupby: Optional[List[ParsedExpr]], - limit: Optional[ParsedExpr], - offset: Optional[ParsedExpr], - sort: Optional[List[OrderByExpr]], + self, + columns: Optional[List[ParsedExpr]], + highlight: Optional[List[ParsedExpr]], + search: Optional[SearchExpr], + filter: Optional[ParsedExpr], + groupby: Optional[List[ParsedExpr]], + limit: Optional[ParsedExpr], + offset: Optional[ParsedExpr], + sort: Optional[List[OrderByExpr]], ): self.columns = columns self.highlight = highlight @@ -61,16 +61,16 @@ def __init__( class ExplainQuery(Query): def __init__( - self, - columns: Optional[List[ParsedExpr]], - highlight: Optional[List[ParsedExpr]], - search: Optional[SearchExpr], - filter: Optional[ParsedExpr], - groupby: Optional[List[ParsedExpr]], - limit: Optional[ParsedExpr], - offset: Optional[ParsedExpr], - sort: Optional[List[OrderByExpr]], - explain_type: Optional[ExplainType], + self, + columns: Optional[List[ParsedExpr]], + highlight: Optional[List[ParsedExpr]], + search: Optional[SearchExpr], + filter: Optional[ParsedExpr], + groupby: Optional[List[ParsedExpr]], + limit: Optional[ParsedExpr], + offset: Optional[ParsedExpr], + sort: Optional[List[OrderByExpr]], + explain_type: Optional[ExplainType], ): super().__init__(columns, highlight, search, filter, groupby, limit, offset, sort) self.explain_type = explain_type @@ -99,13 +99,13 @@ def reset(self): self._sort = None def match_dense( - self, - vector_column_name: str, - embedding_data: VEC, - embedding_data_type: str, - distance_type: str, - topn: int, - knn_params: {} = None, + self, + vector_column_name: str, + embedding_data: VEC, + embedding_data_type: str, + distance_type: str, + topn: int, + knn_params: {} = None, ) -> InfinityThriftQueryBuilder: if self._search is None: self._search = SearchExpr() @@ -134,7 +134,8 @@ def match_dense( if embedding_data_type == "bit": if len(embedding_data) % 8 != 0: raise InfinityException( - ErrorCode.INVALID_EMBEDDING_DATA_TYPE, f"Embeddings with data bit must have dimension of times of 8!" + ErrorCode.INVALID_EMBEDDING_DATA_TYPE, + f"Embeddings with data bit must have dimension of times of 8!" ) else: new_embedding_data = [] @@ -186,7 +187,8 @@ def match_dense( elem_type = ElementType.ElementBFloat16 data.bf16_array_value = embedding_data else: - raise InfinityException(ErrorCode.INVALID_EMBEDDING_DATA_TYPE, f"Invalid embedding {embedding_data[0]} type") + raise InfinityException(ErrorCode.INVALID_EMBEDDING_DATA_TYPE, + f"Invalid embedding {embedding_data[0]} type") dist_type = KnnDistanceType.L2 if distance_type == "l2": @@ -223,12 +225,12 @@ def match_dense( return self def match_sparse( - self, - vector_column_name: str, - sparse_data: SparseVector | dict, - metric_type: str, - topn: int, - opt_params: Optional[dict] = None, + self, + vector_column_name: str, + sparse_data: SparseVector | dict, + metric_type: str, + topn: int, + opt_params: Optional[dict] = None, ) -> InfinityThriftQueryBuilder: if self._search is None: self._search = SearchExpr() @@ -243,7 +245,7 @@ def match_sparse( return self def match_text( - self, fields: str, matching_text: str, topn: int, extra_options: Optional[dict] + self, fields: str, matching_text: str, topn: int, extra_options: Optional[dict] ) -> InfinityThriftQueryBuilder: if self._search is None: self._search = SearchExpr() @@ -262,12 +264,12 @@ def match_text( return self def match_tensor( - self, - column_name: str, - query_data: VEC, - query_data_type: str, - topn: int, - extra_option: Optional[dict] = None, + self, + column_name: str, + query_data: VEC, + query_data_type: str, + topn: int, + extra_option: Optional[dict] = None, ) -> InfinityThriftQueryBuilder: if self._search is None: self._search = SearchExpr() @@ -382,7 +384,7 @@ def highlight(self, columns: Optional[list]) -> InfinityThriftQueryBuilder: self._highlight = highlight_list return self - def sort(self, order_by_expr_list: Optional[List[list[str, bool]]]) -> InfinityThriftQueryBuilder: + def sort(self, order_by_expr_list: Optional[List[list[str, SortType]]]) -> InfinityThriftQueryBuilder: sort_list: List[OrderByExpr] = [] for order_by_expr in order_by_expr_list: if isinstance(order_by_expr[0], str): @@ -393,35 +395,41 @@ def sort(self, order_by_expr_list: Optional[List[list[str, bool]]]) -> InfinityT column_expr = ColumnExpr(star=True, column_name=[]) expr_type = ParsedExprType(column_expr=column_expr) parsed_expr = ParsedExpr(type=expr_type) - order_by_expr = OrderByExpr(expr = parsed_expr, asc = order_by_expr[1]) + order_by_flag: bool = order_by_expr[1] == SortType.Asc + order_by_expr = OrderByExpr(expr=parsed_expr, asc=order_by_flag) sort_list.append(order_by_expr) case "_row_id": func_expr = FunctionExpr(function_name="row_id", arguments=[]) expr_type = ParsedExprType(function_expr=func_expr) parsed_expr = ParsedExpr(type=expr_type) - order_by_expr = OrderByExpr(expr = parsed_expr, asc = order_by_expr[1]) + order_by_flag: bool = order_by_expr[1] == SortType.Asc + order_by_expr = OrderByExpr(expr=parsed_expr, asc=order_by_flag) sort_list.append(order_by_expr) case "_score": func_expr = FunctionExpr(function_name="score", arguments=[]) expr_type = ParsedExprType(function_expr=func_expr) parsed_expr = ParsedExpr(type=expr_type) - order_by_expr = OrderByExpr(expr = parsed_expr, asc = order_by_expr[1]) + order_by_flag: bool = order_by_expr[1] == SortType.Asc + order_by_expr = OrderByExpr(expr=parsed_expr, asc=order_by_flag) sort_list.append(order_by_expr) case "_similarity": func_expr = FunctionExpr(function_name="similarity", arguments=[]) expr_type = ParsedExprType(function_expr=func_expr) parsed_expr = ParsedExpr(type=expr_type) - order_by_expr = OrderByExpr(expr = parsed_expr, asc = order_by_expr[1]) + order_by_flag: bool = order_by_expr[1] == SortType.Asc + order_by_expr = OrderByExpr(expr=parsed_expr, asc=order_by_flag) sort_list.append(order_by_expr) case "_distance": func_expr = FunctionExpr(function_name="distance", arguments=[]) expr_type = ParsedExprType(function_expr=func_expr) parsed_expr = ParsedExpr(type=expr_type) - order_by_expr = OrderByExpr(expr = parsed_expr, asc = order_by_expr[1]) + order_by_flag: bool = order_by_expr[1] == SortType.Asc + order_by_expr = OrderByExpr(expr=parsed_expr, asc=order_by_flag) sort_list.append(order_by_expr) case _: parsed_expr = parse_expr(maybe_parse(order_by_expr[0])) - sort_list.append(OrderByExpr(expr = parsed_expr, asc = order_by_expr[1])) + order_by_flag: bool = order_by_expr[1] == SortType.Asc + sort_list.append(OrderByExpr(expr=parsed_expr, asc=order_by_flag)) self._sort = sort_list return self @@ -463,7 +471,7 @@ def explain(self, explain_type=ExplainType.Physical) -> Any: groupby=self._groupby, limit=self._limit, offset=self._offset, - sort = self._sort, + sort=self._sort, explain_type=explain_type, ) return self._table._explain_query(query) diff --git a/python/infinity_sdk/infinity/remote_thrift/table.py b/python/infinity_sdk/infinity/remote_thrift/table.py index ab29f04377..284341daf6 100644 --- a/python/infinity_sdk/infinity/remote_thrift/table.py +++ b/python/infinity_sdk/infinity/remote_thrift/table.py @@ -397,14 +397,10 @@ def offset(self, offset: Optional[int]): def sort(self, order_by_expr_list: Optional[List[list[str, SortType]]]): for order_by_expr in order_by_expr_list: if len(order_by_expr) != 2: - raise InfinityException(ErrorCode.INVALID_PARAMETER, + raise InfinityException(ErrorCode.INVALID_PARAMETER_VALUE, "order_by_expr_list must be a list of [column_name, sort_type]") if order_by_expr[1] not in [SortType.Asc, SortType.Desc]: - raise InfinityException(ErrorCode.INVALID_PARAMETER, "sort_type must be SortType.Asc or SortType.Desc") - if order_by_expr[1] == SortType.Asc: - order_by_expr[1] = True - else: - order_by_expr[1] = False + raise InfinityException(ErrorCode.INVALID_PARAMETER_VALUE, "sort_type must be SortType.Asc or SortType.Desc") self.query_builder.sort(order_by_expr_list) return self